source: src/GenPoly/Specialize.cc @ 64eae56

ADTaaron-thesisarm-ehast-experimentalcleanup-dtorsdeferred_resndemanglerenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerpthread-emulationqualifiedEnumresolv-newwith_gc
Last change on this file since 64eae56 was 64eae56, checked in by Rob Schluntz <rschlunt@…>, 8 years ago

match formal parameter type of actual function when specializing ttype parameter, flatten types when unifying ttype parameters

  • Property mode set to 100644
File size: 15.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Rob Schluntz
12// Last Modified On : Thu Apr 28 15:17:45 2016
13// Update Count     : 24
14//
15
16#include <cassert>
17
18#include "Specialize.h"
19#include "GenPoly.h"
20#include "PolyMutator.h"
21
22#include "Parser/ParseNode.h"
23
24#include "SynTree/Expression.h"
25#include "SynTree/Statement.h"
26#include "SynTree/Type.h"
27#include "SynTree/Attribute.h"
28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
31#include "Common/UniqueName.h"
32#include "Common/utility.h"
33#include "InitTweak/InitTweak.h"
34
35namespace GenPoly {
36        class Specializer;
37        class Specialize final : public PolyMutator {
38                friend class Specializer;
39          public:
40                using PolyMutator::mutate;
41                virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
42                virtual Expression * mutate( AddressExpr *castExpr ) override;
43                virtual Expression * mutate( CastExpr *castExpr ) override;
44                // virtual Expression * mutate( LogicalExpr *logicalExpr );
45                // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
46                // virtual Expression * mutate( CommaExpr *commaExpr );
47
48                Specializer * specializer = nullptr;
49                void handleExplicitParams( ApplicationExpr *appExpr );
50        };
51
52        class Specializer {
53          public:
54                Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
55                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
56                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
57                virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
58
59          protected:
60                Specialize & spec;
61                std::string paramPrefix = "_p";
62                TypeSubstitution *& env;
63                std::list< Statement * > & stmtsToAdd;
64        };
65
66        // for normal polymorphic -> monomorphic function conversion
67        class PolySpecializer : public Specializer {
68          public:
69                PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
70                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
71                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
72        };
73
74        // // for tuple -> non-tuple function conversion
75        class TupleSpecializer : public Specializer {
76          public:
77                TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
78                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
79                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
80        };
81
82        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
83        bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
84                if ( env ) {
85                        using namespace ResolvExpr;
86                        OpenVarSet openVars, closedVars;
87                        AssertionSet need, have;
88                        findOpenVars( formalType, openVars, closedVars, need, have, false );
89                        findOpenVars( actualType, openVars, closedVars, need, have, true );
90                        for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
91                                Type *boundType = env->lookup( openVar->first );
92                                if ( ! boundType ) continue;
93                                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
94                                        if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
95                                                return true;
96                                        } // if
97                                } else {
98                                        return true;
99                                } // if
100                        } // for
101                        return false;
102                } else {
103                        return false;
104                } // if
105        }
106
107        /// Generates a thunk that calls `actual` with type `funType` and returns its address
108        Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
109                static UniqueName thunkNamer( "_thunk" );
110
111                FunctionType *newType = funType->clone();
112                if ( env ) {
113                        TypeSubstitution newEnv( *env );
114                        // it is important to replace only occurrences of type variables that occur free in the
115                        // thunk's type
116                        newEnv.applyFree( newType );
117                } // if
118                // create new thunk with same signature as formal type (C linkage, empty body)
119                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
120                thunkFunc->fixUniqueId();
121
122                // thunks may be generated and not used - silence warning with attribute
123                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
124
125                // thread thunk parameters into call to actual function, naming thunk parameters as we go
126                UniqueName paramNamer( paramPrefix );
127                ApplicationExpr *appExpr = new ApplicationExpr( actual );
128                for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
129                        (*param )->set_name( paramNamer.newName() );
130                        appExpr->get_args().push_back( new VariableExpr( *param ) );
131                } // for
132                appExpr->set_env( maybeClone( env ) );
133                if ( inferParams ) {
134                        appExpr->get_inferParams() = *inferParams;
135                } // if
136
137                // handle any specializations that may still be present
138                std::string oldParamPrefix = paramPrefix;
139                paramPrefix += "p";
140                // save stmtsToAdd in oldStmts
141                std::list< Statement* > oldStmts;
142                oldStmts.splice( oldStmts.end(), stmtsToAdd );
143                spec.handleExplicitParams( appExpr );
144                paramPrefix = oldParamPrefix;
145                // write any statements added for recursive specializations into the thunk body
146                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
147                // restore oldStmts into stmtsToAdd
148                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
149
150                // add return (or valueless expression) to the thunk
151                Statement *appStmt;
152                if ( funType->get_returnVals().empty() ) {
153                        appStmt = new ExprStmt( noLabels, appExpr );
154                } else {
155                        appStmt = new ReturnStmt( noLabels, appExpr );
156                } // if
157                thunkFunc->get_statements()->get_kids().push_back( appStmt );
158
159                // add thunk definition to queue of statements to add
160                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
161                // return address of thunk function as replacement expression
162                return new AddressExpr( new VariableExpr( thunkFunc ) );
163        }
164
165        Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
166                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
167                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
168                        FunctionType *funType;
169                        if ( ( funType = getFunctionType( formalType ) ) ) {
170                                ApplicationExpr *appExpr;
171                                VariableExpr *varExpr;
172                                if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
173                                        return createThunkFunction( funType, appExpr->get_function(), inferParams );
174                                } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
175                                        return createThunkFunction( funType, varExpr, inferParams );
176                                } else {
177                                        // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
178                                        return createThunkFunction( funType, actual, inferParams );
179                                }
180                        } else {
181                                return actual;
182                        } // if
183                } else {
184                        return actual;
185                } // if
186        }
187
188        bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
189                // std::cerr << "asking if type needs tuple spec: " << formalType << std::endl;
190                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
191                        return ftype->isTtype();
192                }
193                return false;
194        }
195
196        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
197        template< typename OutIterator >
198        void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
199                if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
200                        std::list< Expression * > exprs;
201                        for ( Type * t : *tupleType ) {
202                                matchOneFormal( arg, idx, t, back_inserter( exprs ) );
203                        }
204                        *out++ = new TupleExpr( exprs );
205                } else {
206                        *out++ = new TupleIndexExpr( arg->clone(), idx++ );
207                }
208        }
209
210        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
211        // [begin, end) are the formal parameters.
212        // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
213        template< typename Iterator >
214        void fixLastArg( std::list< Expression * > & args, Iterator begin, Iterator end ) {
215                assertf( ! args.empty(), "Somehow args to tuple function are empty" ); // xxx - it's quite possible this will trigger for the nullary case...
216                Expression * last = args.back();
217                // safe_dynamic_cast for the assertion
218                safe_dynamic_cast< TupleType * >( last->get_result() ); // xxx - it's quite possible this will trigger for the unary case...
219                args.pop_back(); // replace last argument in the call with
220                unsigned idx = 0;
221                for ( ; begin != end; ++begin ) {
222                        DeclarationWithType * formal = *begin;
223                        Type * formalType = formal->get_type();
224                        matchOneFormal( last, idx, formalType, back_inserter( args ) );
225                }
226                delete last;
227        }
228
229        Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
230                static UniqueName thunkNamer( "_tupleThunk" );
231                // std::cerr << "creating tuple thunk for " << funType << std::endl;
232
233                FunctionType *newType = funType->clone();
234                if ( env ) {
235                        TypeSubstitution newEnv( *env );
236                        // it is important to replace only occurrences of type variables that occur free in the
237                        // thunk's type
238                        newEnv.applyFree( newType );
239                } // if
240                // create new thunk with same signature as formal type (C linkage, empty body)
241                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
242                thunkFunc->fixUniqueId();
243
244                // thunks may be generated and not used - silence warning with attribute
245                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
246
247                // thread thunk parameters into call to actual function, naming thunk parameters as we go
248                UniqueName paramNamer( paramPrefix );
249                ApplicationExpr *appExpr = new ApplicationExpr( actual );
250                // std::cerr << actual << std::endl;
251
252                FunctionType * actualType = getFunctionType( actual->get_result() );
253                std::list< DeclarationWithType * >::iterator begin = actualType->get_parameters().begin();
254                std::list< DeclarationWithType * >::iterator end = actualType->get_parameters().end();
255
256                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
257                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
258                        assert( begin != end );
259                        ++begin;
260
261                        // std::cerr << "thunk param: " << param << std::endl;
262                        // last param will always be a tuple type... expand it into the actual type(?)
263                        param->set_name( paramNamer.newName() );
264                        appExpr->get_args().push_back( new VariableExpr( param ) );
265                } // for
266                fixLastArg( appExpr->get_args(), --begin, end );
267                appExpr->set_env( maybeClone( env ) );
268                if ( inferParams ) {
269                        appExpr->get_inferParams() = *inferParams;
270                } // if
271
272                // handle any specializations that may still be present
273                std::string oldParamPrefix = paramPrefix;
274                paramPrefix += "p";
275                // save stmtsToAdd in oldStmts
276                std::list< Statement* > oldStmts;
277                oldStmts.splice( oldStmts.end(), stmtsToAdd );
278                spec.handleExplicitParams( appExpr );
279                paramPrefix = oldParamPrefix;
280                // write any statements added for recursive specializations into the thunk body
281                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
282                // restore oldStmts into stmtsToAdd
283                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
284
285                // add return (or valueless expression) to the thunk
286                Statement *appStmt;
287                if ( funType->get_returnVals().empty() ) {
288                        appStmt = new ExprStmt( noLabels, appExpr );
289                } else {
290                        appStmt = new ReturnStmt( noLabels, appExpr );
291                } // if
292                thunkFunc->get_statements()->get_kids().push_back( appStmt );
293
294                // std::cerr << "thunkFunc is: " << thunkFunc << std::endl;
295
296                // add thunk definition to queue of statements to add
297                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
298                // return address of thunk function as replacement expression
299                return new AddressExpr( new VariableExpr( thunkFunc ) );
300        }
301
302        void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
303                // create thunks for the explicit parameters
304                assert( appExpr->get_function()->has_result() );
305                FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
306                assert( function );
307                std::list< DeclarationWithType* >::iterator formal;
308                std::list< Expression* >::iterator actual;
309                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
310                        *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
311                }
312        }
313
314        Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
315                appExpr->get_function()->acceptMutator( *this );
316                mutateAll( appExpr->get_args(), *this );
317
318                if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
319                        // create thunks for the inferred parameters
320                        // don't need to do this for intrinsic calls, because they aren't actually passed
321                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
322                                inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, &appExpr->get_inferParams() );
323                        }
324                        handleExplicitParams( appExpr );
325                }
326                return appExpr;
327        }
328
329        Expression * Specialize::mutate( AddressExpr *addrExpr ) {
330                addrExpr->get_arg()->acceptMutator( *this );
331                assert( addrExpr->has_result() );
332                addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
333                return addrExpr;
334        }
335
336        Expression * Specialize::mutate( CastExpr *castExpr ) {
337                castExpr->get_arg()->acceptMutator( *this );
338                if ( castExpr->get_result()->isVoid() ) {
339                        // can't specialize if we don't have a return value
340                        return castExpr;
341                }
342                Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
343                if ( specialized != castExpr->get_arg() ) {
344                        // assume here that the specialization incorporates the cast
345                        return specialized;
346                } else {
347                        return castExpr;
348                }
349        }
350
351        // Removing these for now. Richard put these in for some reason, but it's not clear why.
352        // In particular, copy constructors produce a comma expression, and with this code the parts
353        // of that comma expression are not specialized, which causes problems.
354
355        // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
356        //      return logicalExpr;
357        // }
358
359        // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
360        //      return condExpr;
361        // }
362
363        // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
364        //      return commaExpr;
365        // }
366
367        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
368                Specialize spec;
369
370                TupleSpecializer tupleSpec( spec );
371                spec.specializer = &tupleSpec;
372                mutateAll( translationUnit, spec );
373
374                PolySpecializer polySpec( spec );
375                spec.specializer = &polySpec;
376                mutateAll( translationUnit, spec );
377        }
378} // namespace GenPoly
379
380// Local Variables: //
381// tab-width: 4 //
382// mode: c++ //
383// compile-command: "make install" //
384// End: //
Note: See TracBrowser for help on using the repository browser.