source: src/GenPoly/Specialize.cc @ 075734f

ADTaaron-thesisarm-ehast-experimentalcleanup-dtorsdeferred_resndemanglerenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerpthread-emulationqualifiedEnumresolv-newwith_gc
Last change on this file since 075734f was 6c3a988f, checked in by Rob Schluntz <rschlunt@…>, 7 years ago

fix inferred parameter data structures to correctly associate parameters with the entity that requested them, modify tuple specialization and unification to work with self-recursive assertions

  • Property mode set to 100644
File size: 15.8 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Rob Schluntz
12// Last Modified On : Thu Apr 28 15:17:45 2016
13// Update Count     : 24
14//
15
16#include <cassert>
17
18#include "Specialize.h"
19#include "GenPoly.h"
20#include "PolyMutator.h"
21
22#include "Parser/ParseNode.h"
23
24#include "SynTree/Expression.h"
25#include "SynTree/Statement.h"
26#include "SynTree/Type.h"
27#include "SynTree/Attribute.h"
28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
31#include "Common/UniqueName.h"
32#include "Common/utility.h"
33#include "InitTweak/InitTweak.h"
34#include "Tuples/Tuples.h"
35
36namespace GenPoly {
37        class Specializer;
38        class Specialize final : public PolyMutator {
39                friend class Specializer;
40          public:
41                using PolyMutator::mutate;
42                virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
43                virtual Expression * mutate( AddressExpr *castExpr ) override;
44                virtual Expression * mutate( CastExpr *castExpr ) override;
45                // virtual Expression * mutate( LogicalExpr *logicalExpr );
46                // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
47                // virtual Expression * mutate( CommaExpr *commaExpr );
48
49                Specializer * specializer = nullptr;
50                void handleExplicitParams( ApplicationExpr *appExpr );
51        };
52
53        class Specializer {
54          public:
55                Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
56                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
57                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
58                virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
59
60          protected:
61                Specialize & spec;
62                std::string paramPrefix = "_p";
63                TypeSubstitution *& env;
64                std::list< Statement * > & stmtsToAdd;
65        };
66
67        // for normal polymorphic -> monomorphic function conversion
68        class PolySpecializer : public Specializer {
69          public:
70                PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
71                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
72                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
73        };
74
75        // // for tuple -> non-tuple function conversion
76        class TupleSpecializer : public Specializer {
77          public:
78                TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
79                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
80                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
81        };
82
83        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
84        bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
85                if ( env ) {
86                        using namespace ResolvExpr;
87                        OpenVarSet openVars, closedVars;
88                        AssertionSet need, have;
89                        findOpenVars( formalType, openVars, closedVars, need, have, false );
90                        findOpenVars( actualType, openVars, closedVars, need, have, true );
91                        for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
92                                Type *boundType = env->lookup( openVar->first );
93                                if ( ! boundType ) continue;
94                                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
95                                        if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
96                                                return true;
97                                        } // if
98                                } else {
99                                        return true;
100                                } // if
101                        } // for
102                        return false;
103                } else {
104                        return false;
105                } // if
106        }
107
108        /// Generates a thunk that calls `actual` with type `funType` and returns its address
109        Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
110                static UniqueName thunkNamer( "_thunk" );
111
112                FunctionType *newType = funType->clone();
113                if ( env ) {
114                        // it is important to replace only occurrences of type variables that occur free in the
115                        // thunk's type
116                        env->applyFree( newType );
117                } // if
118                // create new thunk with same signature as formal type (C linkage, empty body)
119                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
120                thunkFunc->fixUniqueId();
121
122                // thunks may be generated and not used - silence warning with attribute
123                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
124
125                // thread thunk parameters into call to actual function, naming thunk parameters as we go
126                UniqueName paramNamer( paramPrefix );
127                ApplicationExpr *appExpr = new ApplicationExpr( actual );
128                for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
129                        (*param )->set_name( paramNamer.newName() );
130                        appExpr->get_args().push_back( new VariableExpr( *param ) );
131                } // for
132                appExpr->set_env( maybeClone( env ) );
133                if ( inferParams ) {
134                        appExpr->get_inferParams() = *inferParams;
135                } // if
136
137                // handle any specializations that may still be present
138                std::string oldParamPrefix = paramPrefix;
139                paramPrefix += "p";
140                // save stmtsToAdd in oldStmts
141                std::list< Statement* > oldStmts;
142                oldStmts.splice( oldStmts.end(), stmtsToAdd );
143                spec.handleExplicitParams( appExpr );
144                paramPrefix = oldParamPrefix;
145                // write any statements added for recursive specializations into the thunk body
146                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
147                // restore oldStmts into stmtsToAdd
148                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
149
150                // add return (or valueless expression) to the thunk
151                Statement *appStmt;
152                if ( funType->get_returnVals().empty() ) {
153                        appStmt = new ExprStmt( noLabels, appExpr );
154                } else {
155                        appStmt = new ReturnStmt( noLabels, appExpr );
156                } // if
157                thunkFunc->get_statements()->get_kids().push_back( appStmt );
158
159                // add thunk definition to queue of statements to add
160                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
161                // return address of thunk function as replacement expression
162                return new AddressExpr( new VariableExpr( thunkFunc ) );
163        }
164
165        Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
166                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
167                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
168                        if ( FunctionType *funType = getFunctionType( formalType ) ) {
169                                ApplicationExpr *appExpr;
170                                VariableExpr *varExpr;
171                                if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
172                                        return createThunkFunction( funType, appExpr->get_function(), inferParams );
173                                } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
174                                        return createThunkFunction( funType, varExpr, inferParams );
175                                } else {
176                                        // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
177                                        return createThunkFunction( funType, actual, inferParams );
178                                }
179                        } else {
180                                return actual;
181                        } // if
182                } else {
183                        return actual;
184                } // if
185        }
186
187        bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
188                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
189                        return ftype->isTtype();
190                }
191                return false;
192        }
193
194        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
195        template< typename OutIterator >
196        void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
197                if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
198                        std::list< Expression * > exprs;
199                        for ( Type * t : *tupleType ) {
200                                matchOneFormal( arg, idx, t, back_inserter( exprs ) );
201                        }
202                        *out++ = new TupleExpr( exprs );
203                } else {
204                        *out++ = new TupleIndexExpr( arg->clone(), idx++ );
205                }
206        }
207
208        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
209        // [begin, end) are the formal parameters.
210        // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
211        template< typename Iterator, typename OutIterator >
212        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
213                // safe_dynamic_cast for the assertion
214                safe_dynamic_cast< TupleType * >( last->get_result() );
215                unsigned idx = 0;
216                for ( ; begin != end; ++begin ) {
217                        DeclarationWithType * formal = *begin;
218                        Type * formalType = formal->get_type();
219                        matchOneFormal( last, idx, formalType, out );
220                }
221                delete last;
222        }
223
224        Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
225                static UniqueName thunkNamer( "_tupleThunk" );
226
227                FunctionType *newType = funType->clone();
228                if ( env ) {
229                        // it is important to replace only occurrences of type variables that occur free in the
230                        // thunk's type
231                        env->applyFree( newType );
232                } // if
233                // create new thunk with same signature as formal type (C linkage, empty body)
234                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
235                thunkFunc->fixUniqueId();
236
237                // thunks may be generated and not used - silence warning with attribute
238                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
239
240                // thread thunk parameters into call to actual function, naming thunk parameters as we go
241                UniqueName paramNamer( paramPrefix );
242                ApplicationExpr *appExpr = new ApplicationExpr( actual );
243
244                FunctionType * actualType = getFunctionType( actual->get_result() )->clone();
245                if ( env ) {
246                        // need to apply the environment to the actual function's type, since it may itself be polymorphic
247                        env->apply( actualType );
248                }
249                std::unique_ptr< FunctionType > actualTypeManager( actualType ); // for RAII
250                std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
251                std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
252                std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
253                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
254
255                Expression * last = nullptr;
256                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
257                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
258                        param->set_name( paramNamer.newName() );
259                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
260                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
261                                last = new VariableExpr( param );
262                                break;
263                        }
264                        assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
265                        ++actualBegin;
266                        ++formalBegin;
267
268                        appExpr->get_args().push_back( new VariableExpr( param ) );
269                } // for
270                assert( last );
271                fixLastArg( last, actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
272                appExpr->set_env( maybeClone( env ) );
273                if ( inferParams ) {
274                        appExpr->get_inferParams() = *inferParams;
275                } // if
276
277                // handle any specializations that may still be present
278                std::string oldParamPrefix = paramPrefix;
279                paramPrefix += "p";
280                // save stmtsToAdd in oldStmts
281                std::list< Statement* > oldStmts;
282                oldStmts.splice( oldStmts.end(), stmtsToAdd );
283                spec.mutate( appExpr );
284                paramPrefix = oldParamPrefix;
285                // write any statements added for recursive specializations into the thunk body
286                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
287                // restore oldStmts into stmtsToAdd
288                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
289
290                // add return (or valueless expression) to the thunk
291                Statement *appStmt;
292                if ( funType->get_returnVals().empty() ) {
293                        appStmt = new ExprStmt( noLabels, appExpr );
294                } else {
295                        appStmt = new ReturnStmt( noLabels, appExpr );
296                } // if
297                thunkFunc->get_statements()->get_kids().push_back( appStmt );
298
299                // add thunk definition to queue of statements to add
300                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
301                // return address of thunk function as replacement expression
302                return new AddressExpr( new VariableExpr( thunkFunc ) );
303        }
304
305        void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
306                // create thunks for the explicit parameters
307                assert( appExpr->get_function()->has_result() );
308                FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
309                assert( function );
310                std::list< DeclarationWithType* >::iterator formal;
311                std::list< Expression* >::iterator actual;
312                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
313                        *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
314                }
315        }
316
317        Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
318                appExpr->get_function()->acceptMutator( *this );
319                mutateAll( appExpr->get_args(), *this );
320
321                if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
322                        // create thunks for the inferred parameters
323                        // don't need to do this for intrinsic calls, because they aren't actually passed
324                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
325                                inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
326                        }
327                        handleExplicitParams( appExpr );
328                }
329                return appExpr;
330        }
331
332        Expression * Specialize::mutate( AddressExpr *addrExpr ) {
333                addrExpr->get_arg()->acceptMutator( *this );
334                assert( addrExpr->has_result() );
335                addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
336                return addrExpr;
337        }
338
339        Expression * Specialize::mutate( CastExpr *castExpr ) {
340                castExpr->get_arg()->acceptMutator( *this );
341                if ( castExpr->get_result()->isVoid() ) {
342                        // can't specialize if we don't have a return value
343                        return castExpr;
344                }
345                Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
346                if ( specialized != castExpr->get_arg() ) {
347                        // assume here that the specialization incorporates the cast
348                        return specialized;
349                } else {
350                        return castExpr;
351                }
352        }
353
354        // Removing these for now. Richard put these in for some reason, but it's not clear why.
355        // In particular, copy constructors produce a comma expression, and with this code the parts
356        // of that comma expression are not specialized, which causes problems.
357
358        // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
359        //      return logicalExpr;
360        // }
361
362        // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
363        //      return condExpr;
364        // }
365
366        // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
367        //      return commaExpr;
368        // }
369
370        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
371                Specialize spec;
372
373                TupleSpecializer tupleSpec( spec );
374                spec.specializer = &tupleSpec;
375                mutateAll( translationUnit, spec );
376
377                PolySpecializer polySpec( spec );
378                spec.specializer = &polySpec;
379                mutateAll( translationUnit, spec );
380        }
381} // namespace GenPoly
382
383// Local Variables: //
384// tab-width: 4 //
385// mode: c++ //
386// compile-command: "make install" //
387// End: //
Note: See TracBrowser for help on using the repository browser.