source: src/GenPoly/Specialize.cc @ 4c8621ac

aaron-thesisarm-ehcleanup-dtorsdeferred_resndemanglerjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerresolv-newwith_gc
Last change on this file since 4c8621ac was 4c8621ac, checked in by Rob Schluntz <rschlunt@…>, 5 years ago

allow construction, destruction, and assignment for empty tuples, allow matching a ttype parameter with an empty tuple, fix specialization to work with empty tuples and void functions

  • Property mode set to 100644
File size: 16.0 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Rob Schluntz
12// Last Modified On : Thu Apr 28 15:17:45 2016
13// Update Count     : 24
14//
15
16#include <cassert>
17
18#include "Specialize.h"
19#include "GenPoly.h"
20#include "PolyMutator.h"
21
22#include "Parser/ParseNode.h"
23
24#include "SynTree/Expression.h"
25#include "SynTree/Statement.h"
26#include "SynTree/Type.h"
27#include "SynTree/Attribute.h"
28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
31#include "Common/UniqueName.h"
32#include "Common/utility.h"
33#include "InitTweak/InitTweak.h"
34#include "Tuples/Tuples.h"
35
36namespace GenPoly {
37        class Specializer;
38        class Specialize final : public PolyMutator {
39                friend class Specializer;
40          public:
41                using PolyMutator::mutate;
42                virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
43                virtual Expression * mutate( AddressExpr *castExpr ) override;
44                virtual Expression * mutate( CastExpr *castExpr ) override;
45                // virtual Expression * mutate( LogicalExpr *logicalExpr );
46                // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
47                // virtual Expression * mutate( CommaExpr *commaExpr );
48
49                Specializer * specializer = nullptr;
50                void handleExplicitParams( ApplicationExpr *appExpr );
51        };
52
53        class Specializer {
54          public:
55                Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
56                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
57                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
58                virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
59
60          protected:
61                Specialize & spec;
62                std::string paramPrefix = "_p";
63                TypeSubstitution *& env;
64                std::list< Statement * > & stmtsToAdd;
65        };
66
67        // for normal polymorphic -> monomorphic function conversion
68        class PolySpecializer : public Specializer {
69          public:
70                PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
71                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
72                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
73        };
74
75        // // for tuple -> non-tuple function conversion
76        class TupleSpecializer : public Specializer {
77          public:
78                TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
79                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
80                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
81        };
82
83        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
84        bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
85                if ( env ) {
86                        using namespace ResolvExpr;
87                        OpenVarSet openVars, closedVars;
88                        AssertionSet need, have;
89                        findOpenVars( formalType, openVars, closedVars, need, have, false );
90                        findOpenVars( actualType, openVars, closedVars, need, have, true );
91                        for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
92                                Type *boundType = env->lookup( openVar->first );
93                                if ( ! boundType ) continue;
94                                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
95                                        if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
96                                                return true;
97                                        } // if
98                                } else {
99                                        return true;
100                                } // if
101                        } // for
102                        return false;
103                } else {
104                        return false;
105                } // if
106        }
107
108        /// Generates a thunk that calls `actual` with type `funType` and returns its address
109        Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
110                static UniqueName thunkNamer( "_thunk" );
111
112                FunctionType *newType = funType->clone();
113                if ( env ) {
114                        TypeSubstitution newEnv( *env );
115                        // it is important to replace only occurrences of type variables that occur free in the
116                        // thunk's type
117                        newEnv.applyFree( newType );
118                } // if
119                // create new thunk with same signature as formal type (C linkage, empty body)
120                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
121                thunkFunc->fixUniqueId();
122
123                // thunks may be generated and not used - silence warning with attribute
124                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
125
126                // thread thunk parameters into call to actual function, naming thunk parameters as we go
127                UniqueName paramNamer( paramPrefix );
128                ApplicationExpr *appExpr = new ApplicationExpr( actual );
129                for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
130                        (*param )->set_name( paramNamer.newName() );
131                        appExpr->get_args().push_back( new VariableExpr( *param ) );
132                } // for
133                appExpr->set_env( maybeClone( env ) );
134                if ( inferParams ) {
135                        appExpr->get_inferParams() = *inferParams;
136                } // if
137
138                // handle any specializations that may still be present
139                std::string oldParamPrefix = paramPrefix;
140                paramPrefix += "p";
141                // save stmtsToAdd in oldStmts
142                std::list< Statement* > oldStmts;
143                oldStmts.splice( oldStmts.end(), stmtsToAdd );
144                spec.handleExplicitParams( appExpr );
145                paramPrefix = oldParamPrefix;
146                // write any statements added for recursive specializations into the thunk body
147                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
148                // restore oldStmts into stmtsToAdd
149                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
150
151                // add return (or valueless expression) to the thunk
152                Statement *appStmt;
153                if ( funType->get_returnVals().empty() ) {
154                        appStmt = new ExprStmt( noLabels, appExpr );
155                } else {
156                        appStmt = new ReturnStmt( noLabels, appExpr );
157                } // if
158                thunkFunc->get_statements()->get_kids().push_back( appStmt );
159
160                // add thunk definition to queue of statements to add
161                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
162                // return address of thunk function as replacement expression
163                return new AddressExpr( new VariableExpr( thunkFunc ) );
164        }
165
166        Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
167                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
168                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
169                        FunctionType *funType;
170                        if ( ( funType = getFunctionType( formalType ) ) ) {
171                                ApplicationExpr *appExpr;
172                                VariableExpr *varExpr;
173                                if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
174                                        return createThunkFunction( funType, appExpr->get_function(), inferParams );
175                                } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
176                                        return createThunkFunction( funType, varExpr, inferParams );
177                                } else {
178                                        // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
179                                        return createThunkFunction( funType, actual, inferParams );
180                                }
181                        } else {
182                                return actual;
183                        } // if
184                } else {
185                        return actual;
186                } // if
187        }
188
189        bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
190                // std::cerr << "asking if type needs tuple spec: " << formalType << std::endl;
191                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
192                        return ftype->isTtype();
193                }
194                return false;
195        }
196
197        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
198        template< typename OutIterator >
199        void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
200                if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
201                        std::list< Expression * > exprs;
202                        for ( Type * t : *tupleType ) {
203                                matchOneFormal( arg, idx, t, back_inserter( exprs ) );
204                        }
205                        *out++ = new TupleExpr( exprs );
206                } else {
207                        *out++ = new TupleIndexExpr( arg->clone(), idx++ );
208                }
209        }
210
211        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
212        // [begin, end) are the formal parameters.
213        // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
214        template< typename Iterator, typename OutIterator >
215        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
216                // safe_dynamic_cast for the assertion
217                safe_dynamic_cast< TupleType * >( last->get_result() ); // xxx - it's quite possible this will trigger for the unary case...
218                unsigned idx = 0;
219                for ( ; begin != end; ++begin ) {
220                        DeclarationWithType * formal = *begin;
221                        Type * formalType = formal->get_type();
222                        matchOneFormal( last, idx, formalType, out );
223                }
224                delete last;
225        }
226
227        Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
228                static UniqueName thunkNamer( "_tupleThunk" );
229                // std::cerr << "creating tuple thunk for " << funType << std::endl;
230
231                FunctionType *newType = funType->clone();
232                if ( env ) {
233                        TypeSubstitution newEnv( *env );
234                        // it is important to replace only occurrences of type variables that occur free in the
235                        // thunk's type
236                        newEnv.applyFree( newType );
237                } // if
238                // create new thunk with same signature as formal type (C linkage, empty body)
239                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
240                thunkFunc->fixUniqueId();
241
242                // thunks may be generated and not used - silence warning with attribute
243                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
244
245                // thread thunk parameters into call to actual function, naming thunk parameters as we go
246                UniqueName paramNamer( paramPrefix );
247                ApplicationExpr *appExpr = new ApplicationExpr( actual );
248                // std::cerr << actual << std::endl;
249
250                FunctionType * actualType = getFunctionType( actual->get_result() );
251                std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
252                std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
253                std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
254                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
255
256                Expression * last = nullptr;
257                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
258                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
259                        param->set_name( paramNamer.newName() );
260                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
261                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
262                                last = new VariableExpr( param );
263                                break;
264                        }
265                        assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
266                        ++actualBegin;
267                        ++formalBegin;
268
269                        appExpr->get_args().push_back( new VariableExpr( param ) );
270                } // for
271                assert( last );
272                fixLastArg( last, actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
273                appExpr->set_env( maybeClone( env ) );
274                if ( inferParams ) {
275                        appExpr->get_inferParams() = *inferParams;
276                } // if
277
278                // handle any specializations that may still be present
279                std::string oldParamPrefix = paramPrefix;
280                paramPrefix += "p";
281                // save stmtsToAdd in oldStmts
282                std::list< Statement* > oldStmts;
283                oldStmts.splice( oldStmts.end(), stmtsToAdd );
284                spec.handleExplicitParams( appExpr );
285                paramPrefix = oldParamPrefix;
286                // write any statements added for recursive specializations into the thunk body
287                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
288                // restore oldStmts into stmtsToAdd
289                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
290
291                // add return (or valueless expression) to the thunk
292                Statement *appStmt;
293                if ( funType->get_returnVals().empty() ) {
294                        appStmt = new ExprStmt( noLabels, appExpr );
295                } else {
296                        appStmt = new ReturnStmt( noLabels, appExpr );
297                } // if
298                thunkFunc->get_statements()->get_kids().push_back( appStmt );
299
300                // std::cerr << "thunkFunc is: " << thunkFunc << std::endl;
301
302                // add thunk definition to queue of statements to add
303                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
304                // return address of thunk function as replacement expression
305                return new AddressExpr( new VariableExpr( thunkFunc ) );
306        }
307
308        void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
309                // create thunks for the explicit parameters
310                assert( appExpr->get_function()->has_result() );
311                FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
312                assert( function );
313                std::list< DeclarationWithType* >::iterator formal;
314                std::list< Expression* >::iterator actual;
315                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
316                        *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
317                }
318        }
319
320        Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
321                appExpr->get_function()->acceptMutator( *this );
322                mutateAll( appExpr->get_args(), *this );
323
324                if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
325                        // create thunks for the inferred parameters
326                        // don't need to do this for intrinsic calls, because they aren't actually passed
327                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
328                                inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, &appExpr->get_inferParams() );
329                        }
330                        handleExplicitParams( appExpr );
331                }
332                return appExpr;
333        }
334
335        Expression * Specialize::mutate( AddressExpr *addrExpr ) {
336                addrExpr->get_arg()->acceptMutator( *this );
337                assert( addrExpr->has_result() );
338                addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
339                return addrExpr;
340        }
341
342        Expression * Specialize::mutate( CastExpr *castExpr ) {
343                castExpr->get_arg()->acceptMutator( *this );
344                if ( castExpr->get_result()->isVoid() ) {
345                        // can't specialize if we don't have a return value
346                        return castExpr;
347                }
348                Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
349                if ( specialized != castExpr->get_arg() ) {
350                        // assume here that the specialization incorporates the cast
351                        return specialized;
352                } else {
353                        return castExpr;
354                }
355        }
356
357        // Removing these for now. Richard put these in for some reason, but it's not clear why.
358        // In particular, copy constructors produce a comma expression, and with this code the parts
359        // of that comma expression are not specialized, which causes problems.
360
361        // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
362        //      return logicalExpr;
363        // }
364
365        // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
366        //      return condExpr;
367        // }
368
369        // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
370        //      return commaExpr;
371        // }
372
373        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
374                Specialize spec;
375
376                TupleSpecializer tupleSpec( spec );
377                spec.specializer = &tupleSpec;
378                mutateAll( translationUnit, spec );
379
380                PolySpecializer polySpec( spec );
381                spec.specializer = &polySpec;
382                mutateAll( translationUnit, spec );
383        }
384} // namespace GenPoly
385
386// Local Variables: //
387// tab-width: 4 //
388// mode: c++ //
389// compile-command: "make install" //
390// End: //
Note: See TracBrowser for help on using the repository browser.