source: src/GenPoly/Specialize.cc @ 63f78f0

ADTaaron-thesisarm-ehast-experimentalcleanup-dtorsdeferred_resndemanglerenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerpthread-emulationqualifiedEnumresolv-newwith_gc
Last change on this file since 63f78f0 was f3b0a07, checked in by Rob Schluntz <rschlunt@…>, 8 years ago

allow ttypes contained in tuple types to unify, refactor and simplify Specialize pass

  • Property mode set to 100644
File size: 12.3 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Rob Schluntz
12// Last Modified On : Thu Apr 28 15:17:45 2016
13// Update Count     : 24
14//
15
16#include <cassert>
17
18#include "Specialize.h"
19#include "GenPoly.h"
20#include "PolyMutator.h"
21
22#include "Parser/ParseNode.h"
23
24#include "SynTree/Expression.h"
25#include "SynTree/Statement.h"
26#include "SynTree/Type.h"
27#include "SynTree/Attribute.h"
28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
31#include "Common/UniqueName.h"
32#include "Common/utility.h"
33#include "InitTweak/InitTweak.h"
34#include "Tuples/Tuples.h"
35
36namespace GenPoly {
37        class Specialize final : public PolyMutator {
38          public:
39                using PolyMutator::mutate;
40                virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
41                virtual Expression * mutate( AddressExpr *castExpr ) override;
42                virtual Expression * mutate( CastExpr *castExpr ) override;
43                // virtual Expression * mutate( LogicalExpr *logicalExpr );
44                // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
45                // virtual Expression * mutate( CommaExpr *commaExpr );
46
47                void handleExplicitParams( ApplicationExpr *appExpr );
48                Expression * createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams );
49                Expression * doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = nullptr );
50
51                std::string paramPrefix = "_p";
52        };
53
54        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
55        bool needsPolySpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
56                if ( env ) {
57                        using namespace ResolvExpr;
58                        OpenVarSet openVars, closedVars;
59                        AssertionSet need, have;
60                        findOpenVars( formalType, openVars, closedVars, need, have, false );
61                        findOpenVars( actualType, openVars, closedVars, need, have, true );
62                        for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
63                                Type *boundType = env->lookup( openVar->first );
64                                if ( ! boundType ) continue;
65                                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
66                                        if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
67                                                return true;
68                                        } // if
69                                } else {
70                                        return true;
71                                } // if
72                        } // for
73                        return false;
74                } else {
75                        return false;
76                } // if
77        }
78
79        bool needsTupleSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
80                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
81                        return ftype->isTtype();
82                }
83                return false;
84        }
85
86        bool needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
87                return needsPolySpecialization( formalType, actualType, env ) || needsTupleSpecialization( formalType, actualType, env );
88        }
89
90        Expression * Specialize::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
91                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
92                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
93                        if ( FunctionType *funType = getFunctionType( formalType ) ) {
94                                ApplicationExpr *appExpr;
95                                VariableExpr *varExpr;
96                                if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
97                                        return createThunkFunction( funType, appExpr->get_function(), inferParams );
98                                } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
99                                        return createThunkFunction( funType, varExpr, inferParams );
100                                } else {
101                                        // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
102                                        return createThunkFunction( funType, actual, inferParams );
103                                }
104                        } else {
105                                return actual;
106                        } // if
107                } else {
108                        return actual;
109                } // if
110        }
111
112        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
113        template< typename OutIterator >
114        void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
115                if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
116                        std::list< Expression * > exprs;
117                        for ( Type * t : *tupleType ) {
118                                matchOneFormal( arg, idx, t, back_inserter( exprs ) );
119                        }
120                        *out++ = new TupleExpr( exprs );
121                } else {
122                        *out++ = new TupleIndexExpr( arg->clone(), idx++ );
123                }
124        }
125
126        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
127        /// [begin, end) are the formal parameters.
128        /// args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
129        template< typename Iterator, typename OutIterator >
130        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
131                if ( Tuples::isTtype( last->get_result() ) ) {
132                        *out++ = last;
133                } else {
134                        // safe_dynamic_cast for the assertion
135                        safe_dynamic_cast< TupleType * >( last->get_result() );
136                        unsigned idx = 0;
137                        for ( ; begin != end; ++begin ) {
138                                DeclarationWithType * formal = *begin;
139                                Type * formalType = formal->get_type();
140                                matchOneFormal( last, idx, formalType, out );
141                        }
142                        delete last;
143                }
144        }
145
146        /// Generates a thunk that calls `actual` with type `funType` and returns its address
147        Expression * Specialize::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
148                static UniqueName thunkNamer( "_thunk" );
149
150                FunctionType *newType = funType->clone();
151                if ( env ) {
152                        // it is important to replace only occurrences of type variables that occur free in the
153                        // thunk's type
154                        env->applyFree( newType );
155                } // if
156                // create new thunk with same signature as formal type (C linkage, empty body)
157                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
158                thunkFunc->fixUniqueId();
159
160                // thunks may be generated and not used - silence warning with attribute
161                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
162
163                // thread thunk parameters into call to actual function, naming thunk parameters as we go
164                UniqueName paramNamer( paramPrefix );
165                ApplicationExpr *appExpr = new ApplicationExpr( actual );
166
167                FunctionType * actualType = getFunctionType( actual->get_result() )->clone();
168                if ( env ) {
169                        // need to apply the environment to the actual function's type, since it may itself be polymorphic
170                        env->apply( actualType );
171                }
172                std::unique_ptr< FunctionType > actualTypeManager( actualType ); // for RAII
173                std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
174                std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
175                std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
176                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
177
178                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
179                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
180                        param->set_name( paramNamer.newName() );
181                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
182                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
183                                fixLastArg( new VariableExpr( param ), actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
184                                break;
185                        }
186                        assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
187                        ++actualBegin;
188                        ++formalBegin;
189
190                        appExpr->get_args().push_back( new VariableExpr( param ) );
191                } // for
192                appExpr->set_env( maybeClone( env ) );
193                if ( inferParams ) {
194                        appExpr->get_inferParams() = *inferParams;
195                } // if
196
197                // handle any specializations that may still be present
198                std::string oldParamPrefix = paramPrefix;
199                paramPrefix += "p";
200                // save stmtsToAdd in oldStmts
201                std::list< Statement* > oldStmts;
202                oldStmts.splice( oldStmts.end(), stmtsToAdd );
203                mutate( appExpr );
204                paramPrefix = oldParamPrefix;
205                // write any statements added for recursive specializations into the thunk body
206                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
207                // restore oldStmts into stmtsToAdd
208                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
209
210                // add return (or valueless expression) to the thunk
211                Statement *appStmt;
212                if ( funType->get_returnVals().empty() ) {
213                        appStmt = new ExprStmt( noLabels, appExpr );
214                } else {
215                        appStmt = new ReturnStmt( noLabels, appExpr );
216                } // if
217                thunkFunc->get_statements()->get_kids().push_back( appStmt );
218
219                // add thunk definition to queue of statements to add
220                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
221                // return address of thunk function as replacement expression
222                return new AddressExpr( new VariableExpr( thunkFunc ) );
223        }
224
225        void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
226                // create thunks for the explicit parameters
227                assert( appExpr->get_function()->has_result() );
228                FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
229                assert( function );
230                std::list< DeclarationWithType* >::iterator formal;
231                std::list< Expression* >::iterator actual;
232                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
233                        *actual = doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
234                }
235        }
236
237        Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
238                appExpr->get_function()->acceptMutator( *this );
239                mutateAll( appExpr->get_args(), *this );
240
241                if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
242                        // create thunks for the inferred parameters
243                        // don't need to do this for intrinsic calls, because they aren't actually passed
244                        // need to handle explicit params before inferred params so that explicit params do not recieve a changed set of inferParams (and change them again)
245                        // alternatively, if order starts to matter then copy appExpr's inferParams and pass them to handleExplicitParams.
246                        handleExplicitParams( appExpr );
247                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
248                                inferParam->second.expr = doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
249                        }
250                }
251                return appExpr;
252        }
253
254        Expression * Specialize::mutate( AddressExpr *addrExpr ) {
255                addrExpr->get_arg()->acceptMutator( *this );
256                assert( addrExpr->has_result() );
257                addrExpr->set_arg( doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
258                return addrExpr;
259        }
260
261        Expression * Specialize::mutate( CastExpr *castExpr ) {
262                castExpr->get_arg()->acceptMutator( *this );
263                if ( castExpr->get_result()->isVoid() ) {
264                        // can't specialize if we don't have a return value
265                        return castExpr;
266                }
267                Expression *specialized = doSpecialization( castExpr->get_result(), castExpr->get_arg() );
268                if ( specialized != castExpr->get_arg() ) {
269                        // assume here that the specialization incorporates the cast
270                        return specialized;
271                } else {
272                        return castExpr;
273                }
274        }
275
276        // Removing these for now. Richard put these in for some reason, but it's not clear why.
277        // In particular, copy constructors produce a comma expression, and with this code the parts
278        // of that comma expression are not specialized, which causes problems.
279
280        // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
281        //      return logicalExpr;
282        // }
283
284        // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
285        //      return condExpr;
286        // }
287
288        // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
289        //      return commaExpr;
290        // }
291
292        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
293                Specialize spec;
294                mutateAll( translationUnit, spec );
295        }
296} // namespace GenPoly
297
298// Local Variables: //
299// tab-width: 4 //
300// mode: c++ //
301// compile-command: "make install" //
302// End: //
Note: See TracBrowser for help on using the repository browser.