source: src/GenPoly/Specialize.cc@ 8118303

ADT aaron-thesis arm-eh ast-experimental cleanup-dtors deferred_resn demangler enum forall-pointer-decay jacob/cs343-translation jenkins-sandbox new-ast new-ast-unique-expr new-env no_list persistent-indexer pthread-emulation qualifiedEnum resolv-new with_gc
Last change on this file since 8118303 was 6c3a988f, checked in by Rob Schluntz <rschlunt@…>, 9 years ago

fix inferred parameter data structures to correctly associate parameters with the entity that requested them, modify tuple specialization and unification to work with self-recursive assertions

  • Property mode set to 100644
File size: 15.8 KB
RevLine 
[51587aa]1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
[f1e012b]7// Specialize.cc --
[51587aa]8//
9// Author : Richard C. Bilson
10// Created On : Mon May 18 07:44:20 2015
[8cbf8cd]11// Last Modified By : Rob Schluntz
[fea7ca7]12// Last Modified On : Thu Apr 28 15:17:45 2016
[771b3c3]13// Update Count : 24
[51587aa]14//
[51b73452]15
16#include <cassert>
17
18#include "Specialize.h"
[7754cde]19#include "GenPoly.h"
[51b73452]20#include "PolyMutator.h"
21
[68cd1ce]22#include "Parser/ParseNode.h"
23
[51b73452]24#include "SynTree/Expression.h"
[68cd1ce]25#include "SynTree/Statement.h"
[51b73452]26#include "SynTree/Type.h"
[64a32c6]27#include "SynTree/Attribute.h"
[51b73452]28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
[d3b7937]31#include "Common/UniqueName.h"
32#include "Common/utility.h"
[aedfd91]33#include "InitTweak/InitTweak.h"
[4c8621ac]34#include "Tuples/Tuples.h"
[51b73452]35
36namespace GenPoly {
[626dbc10]37 class Specializer;
[62e5546]38 class Specialize final : public PolyMutator {
[626dbc10]39 friend class Specializer;
[01aeade]40 public:
[62e5546]41 using PolyMutator::mutate;
42 virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
43 virtual Expression * mutate( AddressExpr *castExpr ) override;
44 virtual Expression * mutate( CastExpr *castExpr ) override;
[fea7ca7]45 // virtual Expression * mutate( LogicalExpr *logicalExpr );
46 // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
47 // virtual Expression * mutate( CommaExpr *commaExpr );
[01aeade]48
[626dbc10]49 Specializer * specializer = nullptr;
[01aeade]50 void handleExplicitParams( ApplicationExpr *appExpr );
[626dbc10]51 };
[01aeade]52
[626dbc10]53 class Specializer {
54 public:
55 Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
56 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
57 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
58 virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
59
60 protected:
61 Specialize & spec;
62 std::string paramPrefix = "_p";
63 TypeSubstitution *& env;
64 std::list< Statement * > & stmtsToAdd;
[01aeade]65 };
66
[626dbc10]67 // for normal polymorphic -> monomorphic function conversion
68 class PolySpecializer : public Specializer {
69 public:
70 PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
71 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
72 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
73 };
[01aeade]74
[626dbc10]75 // // for tuple -> non-tuple function conversion
76 class TupleSpecializer : public Specializer {
77 public:
78 TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
79 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
80 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
81 };
[01aeade]82
[698664b3]83 /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
[626dbc10]84 bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
[01aeade]85 if ( env ) {
86 using namespace ResolvExpr;
87 OpenVarSet openVars, closedVars;
88 AssertionSet need, have;
89 findOpenVars( formalType, openVars, closedVars, need, have, false );
90 findOpenVars( actualType, openVars, closedVars, need, have, true );
91 for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
92 Type *boundType = env->lookup( openVar->first );
93 if ( ! boundType ) continue;
94 if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
95 if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
96 return true;
97 } // if
98 } else {
99 return true;
100 } // if
101 } // for
102 return false;
103 } else {
104 return false;
105 } // if
106 }
107
[698664b3]108 /// Generates a thunk that calls `actual` with type `funType` and returns its address
[626dbc10]109 Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
110 static UniqueName thunkNamer( "_thunk" );
111
[698664b3]112 FunctionType *newType = funType->clone();
113 if ( env ) {
114 // it is important to replace only occurrences of type variables that occur free in the
115 // thunk's type
[6c3a988f]116 env->applyFree( newType );
[698664b3]117 } // if
118 // create new thunk with same signature as formal type (C linkage, empty body)
[0f8e4ac]119 FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
[698664b3]120 thunkFunc->fixUniqueId();
121
[64a32c6]122 // thunks may be generated and not used - silence warning with attribute
123 thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
124
[698664b3]125 // thread thunk parameters into call to actual function, naming thunk parameters as we go
126 UniqueName paramNamer( paramPrefix );
127 ApplicationExpr *appExpr = new ApplicationExpr( actual );
128 for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
129 (*param )->set_name( paramNamer.newName() );
130 appExpr->get_args().push_back( new VariableExpr( *param ) );
131 } // for
132 appExpr->set_env( maybeClone( env ) );
133 if ( inferParams ) {
134 appExpr->get_inferParams() = *inferParams;
135 } // if
136
137 // handle any specializations that may still be present
138 std::string oldParamPrefix = paramPrefix;
139 paramPrefix += "p";
140 // save stmtsToAdd in oldStmts
141 std::list< Statement* > oldStmts;
142 oldStmts.splice( oldStmts.end(), stmtsToAdd );
[626dbc10]143 spec.handleExplicitParams( appExpr );
[698664b3]144 paramPrefix = oldParamPrefix;
145 // write any statements added for recursive specializations into the thunk body
146 thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
147 // restore oldStmts into stmtsToAdd
148 stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
149
150 // add return (or valueless expression) to the thunk
151 Statement *appStmt;
152 if ( funType->get_returnVals().empty() ) {
153 appStmt = new ExprStmt( noLabels, appExpr );
154 } else {
155 appStmt = new ReturnStmt( noLabels, appExpr );
156 } // if
157 thunkFunc->get_statements()->get_kids().push_back( appStmt );
158
159 // add thunk definition to queue of statements to add
160 stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
161 // return address of thunk function as replacement expression
162 return new AddressExpr( new VariableExpr( thunkFunc ) );
163 }
[f1e012b]164
[626dbc10]165 Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
[b3b2077]166 assertf( actual->has_result(), "attempting to specialize an untyped expression" );
[906e24d]167 if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
[6c3a988f]168 if ( FunctionType *funType = getFunctionType( formalType ) ) {
[698664b3]169 ApplicationExpr *appExpr;
170 VariableExpr *varExpr;
171 if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
172 return createThunkFunction( funType, appExpr->get_function(), inferParams );
173 } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
174 return createThunkFunction( funType, varExpr, inferParams );
[01aeade]175 } else {
[698664b3]176 // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
177 return createThunkFunction( funType, actual, inferParams );
178 }
[01aeade]179 } else {
180 return actual;
181 } // if
182 } else {
183 return actual;
184 } // if
185 }
186
[626dbc10]187 bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
188 if ( FunctionType * ftype = getFunctionType( formalType ) ) {
189 return ftype->isTtype();
190 }
191 return false;
192 }
193
[64eae56]194 /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
195 template< typename OutIterator >
196 void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
197 if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
198 std::list< Expression * > exprs;
199 for ( Type * t : *tupleType ) {
200 matchOneFormal( arg, idx, t, back_inserter( exprs ) );
201 }
202 *out++ = new TupleExpr( exprs );
203 } else {
204 *out++ = new TupleIndexExpr( arg->clone(), idx++ );
205 }
206 }
207
208 /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
209 // [begin, end) are the formal parameters.
210 // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
[4c8621ac]211 template< typename Iterator, typename OutIterator >
212 void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
[626dbc10]213 // safe_dynamic_cast for the assertion
[6c3a988f]214 safe_dynamic_cast< TupleType * >( last->get_result() );
[626dbc10]215 unsigned idx = 0;
216 for ( ; begin != end; ++begin ) {
[64eae56]217 DeclarationWithType * formal = *begin;
218 Type * formalType = formal->get_type();
[4c8621ac]219 matchOneFormal( last, idx, formalType, out );
[626dbc10]220 }
221 delete last;
222 }
223
224 Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
225 static UniqueName thunkNamer( "_tupleThunk" );
226
227 FunctionType *newType = funType->clone();
228 if ( env ) {
229 // it is important to replace only occurrences of type variables that occur free in the
230 // thunk's type
[6c3a988f]231 env->applyFree( newType );
[626dbc10]232 } // if
233 // create new thunk with same signature as formal type (C linkage, empty body)
234 FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
235 thunkFunc->fixUniqueId();
236
237 // thunks may be generated and not used - silence warning with attribute
238 thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
239
240 // thread thunk parameters into call to actual function, naming thunk parameters as we go
241 UniqueName paramNamer( paramPrefix );
242 ApplicationExpr *appExpr = new ApplicationExpr( actual );
243
[6c3a988f]244 FunctionType * actualType = getFunctionType( actual->get_result() )->clone();
245 if ( env ) {
246 // need to apply the environment to the actual function's type, since it may itself be polymorphic
247 env->apply( actualType );
248 }
249 std::unique_ptr< FunctionType > actualTypeManager( actualType ); // for RAII
[4c8621ac]250 std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
251 std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
252 std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
253 std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
[626dbc10]254
[4c8621ac]255 Expression * last = nullptr;
[626dbc10]256 for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
[64eae56]257 // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
[626dbc10]258 param->set_name( paramNamer.newName() );
[4c8621ac]259 assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
260 if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
261 last = new VariableExpr( param );
262 break;
263 }
264 assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
265 ++actualBegin;
266 ++formalBegin;
267
[626dbc10]268 appExpr->get_args().push_back( new VariableExpr( param ) );
269 } // for
[4c8621ac]270 assert( last );
271 fixLastArg( last, actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
[626dbc10]272 appExpr->set_env( maybeClone( env ) );
273 if ( inferParams ) {
274 appExpr->get_inferParams() = *inferParams;
275 } // if
276
277 // handle any specializations that may still be present
278 std::string oldParamPrefix = paramPrefix;
279 paramPrefix += "p";
280 // save stmtsToAdd in oldStmts
281 std::list< Statement* > oldStmts;
282 oldStmts.splice( oldStmts.end(), stmtsToAdd );
[6c3a988f]283 spec.mutate( appExpr );
[626dbc10]284 paramPrefix = oldParamPrefix;
285 // write any statements added for recursive specializations into the thunk body
286 thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
287 // restore oldStmts into stmtsToAdd
288 stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
289
290 // add return (or valueless expression) to the thunk
291 Statement *appStmt;
292 if ( funType->get_returnVals().empty() ) {
293 appStmt = new ExprStmt( noLabels, appExpr );
294 } else {
295 appStmt = new ReturnStmt( noLabels, appExpr );
296 } // if
297 thunkFunc->get_statements()->get_kids().push_back( appStmt );
298
299 // add thunk definition to queue of statements to add
300 stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
301 // return address of thunk function as replacement expression
302 return new AddressExpr( new VariableExpr( thunkFunc ) );
303 }
304
[01aeade]305 void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
306 // create thunks for the explicit parameters
[906e24d]307 assert( appExpr->get_function()->has_result() );
308 FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
[698664b3]309 assert( function );
[01aeade]310 std::list< DeclarationWithType* >::iterator formal;
311 std::list< Expression* >::iterator actual;
312 for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
[626dbc10]313 *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
[01aeade]314 }
315 }
316
317 Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
318 appExpr->get_function()->acceptMutator( *this );
319 mutateAll( appExpr->get_args(), *this );
320
[aedfd91]321 if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
322 // create thunks for the inferred parameters
323 // don't need to do this for intrinsic calls, because they aren't actually passed
324 for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
[6c3a988f]325 inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
[aedfd91]326 }
327 handleExplicitParams( appExpr );
328 }
[01aeade]329 return appExpr;
330 }
331
332 Expression * Specialize::mutate( AddressExpr *addrExpr ) {
333 addrExpr->get_arg()->acceptMutator( *this );
[906e24d]334 assert( addrExpr->has_result() );
[626dbc10]335 addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
[01aeade]336 return addrExpr;
337 }
338
339 Expression * Specialize::mutate( CastExpr *castExpr ) {
340 castExpr->get_arg()->acceptMutator( *this );
[906e24d]341 if ( castExpr->get_result()->isVoid() ) {
[803deb1]342 // can't specialize if we don't have a return value
343 return castExpr;
344 }
[626dbc10]345 Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
[698664b3]346 if ( specialized != castExpr->get_arg() ) {
347 // assume here that the specialization incorporates the cast
348 return specialized;
349 } else {
350 return castExpr;
351 }
[01aeade]352 }
353
// These overrides are disabled for now. Richard originally added them for reasons that are no longer clear.
// In particular, copy constructors produce a comma expression, and with these overrides the operands
// of that comma expression are not specialized, which causes problems.
[01aeade]357
[fea7ca7]358 // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
359 // return logicalExpr;
360 // }
[01aeade]361
[fea7ca7]362 // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
363 // return condExpr;
364 // }
365
366 // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
367 // return commaExpr;
368 // }
[626dbc10]369
370 void convertSpecializations( std::list< Declaration* >& translationUnit ) {
371 Specialize spec;
372
373 TupleSpecializer tupleSpec( spec );
374 spec.specializer = &tupleSpec;
375 mutateAll( translationUnit, spec );
376
377 PolySpecializer polySpec( spec );
378 spec.specializer = &polySpec;
379 mutateAll( translationUnit, spec );
380 }
[51b73452]381} // namespace GenPoly
[01aeade]382
[51587aa]383// Local Variables: //
384// tab-width: 4 //
385// mode: c++ //
386// compile-command: "make install" //
387// End: //
Note: See TracBrowser for help on using the repository browser.