source: src/GenPoly/Specialize.cc@ 907eccb

ADT aaron-thesis arm-eh ast-experimental cleanup-dtors deferred_resn demangler enum forall-pointer-decay jacob/cs343-translation jenkins-sandbox new-ast new-ast-unique-expr new-env no_list persistent-indexer pthread-emulation qualifiedEnum resolv-new with_gc
Last change on this file since 907eccb was 64eae56, checked in by Rob Schluntz <rschlunt@…>, 9 years ago

match formal parameter type of actual function when specializing ttype parameter, flatten types when unifying ttype parameters

  • Property mode set to 100644
File size: 15.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cc --
8//
9// Author : Richard C. Bilson
10// Created On : Mon May 18 07:44:20 2015
11// Last Modified By : Rob Schluntz
12// Last Modified On : Thu Apr 28 15:17:45 2016
13// Update Count : 24
14//
15
16#include <cassert>
17
18#include "Specialize.h"
19#include "GenPoly.h"
20#include "PolyMutator.h"
21
22#include "Parser/ParseNode.h"
23
24#include "SynTree/Expression.h"
25#include "SynTree/Statement.h"
26#include "SynTree/Type.h"
27#include "SynTree/Attribute.h"
28#include "SynTree/TypeSubstitution.h"
29#include "SynTree/Mutator.h"
30#include "ResolvExpr/FindOpenVars.h"
31#include "Common/UniqueName.h"
32#include "Common/utility.h"
33#include "InitTweak/InitTweak.h"
34
35namespace GenPoly {
36 class Specializer;
37 class Specialize final : public PolyMutator {
38 friend class Specializer;
39 public:
40 using PolyMutator::mutate;
41 virtual Expression * mutate( ApplicationExpr *applicationExpr ) override;
42 virtual Expression * mutate( AddressExpr *castExpr ) override;
43 virtual Expression * mutate( CastExpr *castExpr ) override;
44 // virtual Expression * mutate( LogicalExpr *logicalExpr );
45 // virtual Expression * mutate( ConditionalExpr *conditionalExpr );
46 // virtual Expression * mutate( CommaExpr *commaExpr );
47
48 Specializer * specializer = nullptr;
49 void handleExplicitParams( ApplicationExpr *appExpr );
50 };
51
52 class Specializer {
53 public:
54 Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
55 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
56 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
57 virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
58
59 protected:
60 Specialize & spec;
61 std::string paramPrefix = "_p";
62 TypeSubstitution *& env;
63 std::list< Statement * > & stmtsToAdd;
64 };
65
66 // for normal polymorphic -> monomorphic function conversion
67 class PolySpecializer : public Specializer {
68 public:
69 PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
70 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
71 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
72 };
73
74 // // for tuple -> non-tuple function conversion
75 class TupleSpecializer : public Specializer {
76 public:
77 TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
78 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
79 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
80 };
81
82 /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
83 bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
84 if ( env ) {
85 using namespace ResolvExpr;
86 OpenVarSet openVars, closedVars;
87 AssertionSet need, have;
88 findOpenVars( formalType, openVars, closedVars, need, have, false );
89 findOpenVars( actualType, openVars, closedVars, need, have, true );
90 for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) {
91 Type *boundType = env->lookup( openVar->first );
92 if ( ! boundType ) continue;
93 if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) {
94 if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) {
95 return true;
96 } // if
97 } else {
98 return true;
99 } // if
100 } // for
101 return false;
102 } else {
103 return false;
104 } // if
105 }
106
107 /// Generates a thunk that calls `actual` with type `funType` and returns its address
108 Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
109 static UniqueName thunkNamer( "_thunk" );
110
111 FunctionType *newType = funType->clone();
112 if ( env ) {
113 TypeSubstitution newEnv( *env );
114 // it is important to replace only occurrences of type variables that occur free in the
115 // thunk's type
116 newEnv.applyFree( newType );
117 } // if
118 // create new thunk with same signature as formal type (C linkage, empty body)
119 FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
120 thunkFunc->fixUniqueId();
121
122 // thunks may be generated and not used - silence warning with attribute
123 thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
124
125 // thread thunk parameters into call to actual function, naming thunk parameters as we go
126 UniqueName paramNamer( paramPrefix );
127 ApplicationExpr *appExpr = new ApplicationExpr( actual );
128 for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
129 (*param )->set_name( paramNamer.newName() );
130 appExpr->get_args().push_back( new VariableExpr( *param ) );
131 } // for
132 appExpr->set_env( maybeClone( env ) );
133 if ( inferParams ) {
134 appExpr->get_inferParams() = *inferParams;
135 } // if
136
137 // handle any specializations that may still be present
138 std::string oldParamPrefix = paramPrefix;
139 paramPrefix += "p";
140 // save stmtsToAdd in oldStmts
141 std::list< Statement* > oldStmts;
142 oldStmts.splice( oldStmts.end(), stmtsToAdd );
143 spec.handleExplicitParams( appExpr );
144 paramPrefix = oldParamPrefix;
145 // write any statements added for recursive specializations into the thunk body
146 thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
147 // restore oldStmts into stmtsToAdd
148 stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
149
150 // add return (or valueless expression) to the thunk
151 Statement *appStmt;
152 if ( funType->get_returnVals().empty() ) {
153 appStmt = new ExprStmt( noLabels, appExpr );
154 } else {
155 appStmt = new ReturnStmt( noLabels, appExpr );
156 } // if
157 thunkFunc->get_statements()->get_kids().push_back( appStmt );
158
159 // add thunk definition to queue of statements to add
160 stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
161 // return address of thunk function as replacement expression
162 return new AddressExpr( new VariableExpr( thunkFunc ) );
163 }
164
165 Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
166 assertf( actual->has_result(), "attempting to specialize an untyped expression" );
167 if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
168 FunctionType *funType;
169 if ( ( funType = getFunctionType( formalType ) ) ) {
170 ApplicationExpr *appExpr;
171 VariableExpr *varExpr;
172 if ( ( appExpr = dynamic_cast<ApplicationExpr*>( actual ) ) ) {
173 return createThunkFunction( funType, appExpr->get_function(), inferParams );
174 } else if ( ( varExpr = dynamic_cast<VariableExpr*>( actual ) ) ) {
175 return createThunkFunction( funType, varExpr, inferParams );
176 } else {
177 // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
178 return createThunkFunction( funType, actual, inferParams );
179 }
180 } else {
181 return actual;
182 } // if
183 } else {
184 return actual;
185 } // if
186 }
187
188 bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
189 // std::cerr << "asking if type needs tuple spec: " << formalType << std::endl;
190 if ( FunctionType * ftype = getFunctionType( formalType ) ) {
191 return ftype->isTtype();
192 }
193 return false;
194 }
195
196 /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
197 template< typename OutIterator >
198 void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
199 if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
200 std::list< Expression * > exprs;
201 for ( Type * t : *tupleType ) {
202 matchOneFormal( arg, idx, t, back_inserter( exprs ) );
203 }
204 *out++ = new TupleExpr( exprs );
205 } else {
206 *out++ = new TupleIndexExpr( arg->clone(), idx++ );
207 }
208 }
209
210 /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
211 // [begin, end) are the formal parameters.
212 // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
213 template< typename Iterator >
214 void fixLastArg( std::list< Expression * > & args, Iterator begin, Iterator end ) {
215 assertf( ! args.empty(), "Somehow args to tuple function are empty" ); // xxx - it's quite possible this will trigger for the nullary case...
216 Expression * last = args.back();
217 // safe_dynamic_cast for the assertion
218 safe_dynamic_cast< TupleType * >( last->get_result() ); // xxx - it's quite possible this will trigger for the unary case...
219 args.pop_back(); // replace last argument in the call with
220 unsigned idx = 0;
221 for ( ; begin != end; ++begin ) {
222 DeclarationWithType * formal = *begin;
223 Type * formalType = formal->get_type();
224 matchOneFormal( last, idx, formalType, back_inserter( args ) );
225 }
226 delete last;
227 }
228
229 Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
230 static UniqueName thunkNamer( "_tupleThunk" );
231 // std::cerr << "creating tuple thunk for " << funType << std::endl;
232
233 FunctionType *newType = funType->clone();
234 if ( env ) {
235 TypeSubstitution newEnv( *env );
236 // it is important to replace only occurrences of type variables that occur free in the
237 // thunk's type
238 newEnv.applyFree( newType );
239 } // if
240 // create new thunk with same signature as formal type (C linkage, empty body)
241 FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
242 thunkFunc->fixUniqueId();
243
244 // thunks may be generated and not used - silence warning with attribute
245 thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
246
247 // thread thunk parameters into call to actual function, naming thunk parameters as we go
248 UniqueName paramNamer( paramPrefix );
249 ApplicationExpr *appExpr = new ApplicationExpr( actual );
250 // std::cerr << actual << std::endl;
251
252 FunctionType * actualType = getFunctionType( actual->get_result() );
253 std::list< DeclarationWithType * >::iterator begin = actualType->get_parameters().begin();
254 std::list< DeclarationWithType * >::iterator end = actualType->get_parameters().end();
255
256 for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
257 // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
258 assert( begin != end );
259 ++begin;
260
261 // std::cerr << "thunk param: " << param << std::endl;
262 // last param will always be a tuple type... expand it into the actual type(?)
263 param->set_name( paramNamer.newName() );
264 appExpr->get_args().push_back( new VariableExpr( param ) );
265 } // for
266 fixLastArg( appExpr->get_args(), --begin, end );
267 appExpr->set_env( maybeClone( env ) );
268 if ( inferParams ) {
269 appExpr->get_inferParams() = *inferParams;
270 } // if
271
272 // handle any specializations that may still be present
273 std::string oldParamPrefix = paramPrefix;
274 paramPrefix += "p";
275 // save stmtsToAdd in oldStmts
276 std::list< Statement* > oldStmts;
277 oldStmts.splice( oldStmts.end(), stmtsToAdd );
278 spec.handleExplicitParams( appExpr );
279 paramPrefix = oldParamPrefix;
280 // write any statements added for recursive specializations into the thunk body
281 thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
282 // restore oldStmts into stmtsToAdd
283 stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
284
285 // add return (or valueless expression) to the thunk
286 Statement *appStmt;
287 if ( funType->get_returnVals().empty() ) {
288 appStmt = new ExprStmt( noLabels, appExpr );
289 } else {
290 appStmt = new ReturnStmt( noLabels, appExpr );
291 } // if
292 thunkFunc->get_statements()->get_kids().push_back( appStmt );
293
294 // std::cerr << "thunkFunc is: " << thunkFunc << std::endl;
295
296 // add thunk definition to queue of statements to add
297 stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
298 // return address of thunk function as replacement expression
299 return new AddressExpr( new VariableExpr( thunkFunc ) );
300 }
301
302 void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) {
303 // create thunks for the explicit parameters
304 assert( appExpr->get_function()->has_result() );
305 FunctionType *function = getFunctionType( appExpr->get_function()->get_result() );
306 assert( function );
307 std::list< DeclarationWithType* >::iterator formal;
308 std::list< Expression* >::iterator actual;
309 for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
310 *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
311 }
312 }
313
314 Expression * Specialize::mutate( ApplicationExpr *appExpr ) {
315 appExpr->get_function()->acceptMutator( *this );
316 mutateAll( appExpr->get_args(), *this );
317
318 if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) {
319 // create thunks for the inferred parameters
320 // don't need to do this for intrinsic calls, because they aren't actually passed
321 for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
322 inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, &appExpr->get_inferParams() );
323 }
324 handleExplicitParams( appExpr );
325 }
326 return appExpr;
327 }
328
329 Expression * Specialize::mutate( AddressExpr *addrExpr ) {
330 addrExpr->get_arg()->acceptMutator( *this );
331 assert( addrExpr->has_result() );
332 addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
333 return addrExpr;
334 }
335
336 Expression * Specialize::mutate( CastExpr *castExpr ) {
337 castExpr->get_arg()->acceptMutator( *this );
338 if ( castExpr->get_result()->isVoid() ) {
339 // can't specialize if we don't have a return value
340 return castExpr;
341 }
342 Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
343 if ( specialized != castExpr->get_arg() ) {
344 // assume here that the specialization incorporates the cast
345 return specialized;
346 } else {
347 return castExpr;
348 }
349 }
350
351 // Removing these for now. Richard put these in for some reason, but it's not clear why.
352 // In particular, copy constructors produce a comma expression, and with this code the parts
353 // of that comma expression are not specialized, which causes problems.
354
355 // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) {
356 // return logicalExpr;
357 // }
358
359 // Expression * Specialize::mutate( ConditionalExpr *condExpr ) {
360 // return condExpr;
361 // }
362
363 // Expression * Specialize::mutate( CommaExpr *commaExpr ) {
364 // return commaExpr;
365 // }
366
367 void convertSpecializations( std::list< Declaration* >& translationUnit ) {
368 Specialize spec;
369
370 TupleSpecializer tupleSpec( spec );
371 spec.specializer = &tupleSpec;
372 mutateAll( translationUnit, spec );
373
374 PolySpecializer polySpec( spec );
375 spec.specializer = &polySpec;
376 mutateAll( translationUnit, spec );
377 }
378} // namespace GenPoly
379
380// Local Variables: //
381// tab-width: 4 //
382// mode: c++ //
383// compile-command: "make install" //
384// End: //
Note: See TracBrowser for help on using the repository browser.