source: src/GenPoly/SpecializeNew.cpp@ b0d9ff7

ADT ast-experimental pthread-emulation qualifiedEnum
Last change on this file since b0d9ff7 was 9e23b446, checked in by Fangren Yu <f37yu@…>, 3 years ago

add specialize pass

  • Property mode set to 100644
File size: 17.4 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecializeNew.cpp --
8//
9// Author : Andrew Beach
10// Created On : Tue Jun 7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 7 13:37:00 2022
13// Update Count : 0
14//
15
16#include "Specialize.h"
17
18#include "AST/Pass.hpp"
19#include "AST/TypeEnvironment.hpp" // for OpenVarSet, AssertionSet
20#include "Common/UniqueName.h" // for UniqueName
21#include "GenPoly/GenPoly.h" // for getFunctionType
22#include "InitTweak/InitTweak.h" // for isIntrinsicCallExpr
23#include "ResolvExpr/FindOpenVars.h" // for findOpenVars
24#include "ResolvExpr/TypeEnvironment.h" // for FirstOpen, FirstClosed
25
26#include "AST/Print.hpp"
27
28namespace GenPoly {
29
30namespace {
31
32struct SpecializeCore final :
33 public ast::WithConstTypeSubstitution,
34 public ast::WithDeclsToAdd<>,
35 public ast::WithVisitorRef<SpecializeCore> {
36 std::string paramPrefix = "_p";
37
38 ast::ApplicationExpr * handleExplicitParams(
39 const ast::ApplicationExpr * expr );
40 const ast::Expr * createThunkFunction(
41 const CodeLocation & location,
42 const ast::FunctionType * funType,
43 const ast::Expr * actual,
44 const ast::InferredParams * inferParams );
45 const ast::Expr * doSpecialization(
46 const CodeLocation & location,
47 const ast::Type * formalType,
48 const ast::Expr * actual,
49 const ast::InferredParams * inferParams );
50
51 const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
52 const ast::Expr * postvisit( const ast::CastExpr * expr );
53};
54
55const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
56 const ast::Expr::InferUnion & inferred = expr->inferred;
57 if ( inferred.hasParams() ) {
58 return &inferred.inferParams();
59 } else {
60 return nullptr;
61 }
62}
63
64// Check if both types have the same structure. The leaf (non-tuple) types
65// don't have to match but the tuples must match.
66bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
67 const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
68 const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
69 if ( tt0 && tt1 ) {
70 if ( tt0->size() != tt1->size() ) {
71 return false;
72 }
73 for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
74 if ( !isTupleStructureMatching(
75 std::get<0>( types ), std::get<1>( types ) ) ) {
76 return false;
77 }
78 }
79 return true;
80 }
81 return (!tt0 && !tt1);
82}
83
84// The number of elements in a type if it is a flattened tuple.
85size_t flatTupleSize( const ast::Type * type ) {
86 if ( auto tuple = dynamic_cast<const ast::TupleType *>( type ) ) {
87 size_t sum = 0;
88 for ( auto t : *tuple ) {
89 sum += flatTupleSize( t );
90 }
91 return sum;
92 } else {
93 return 1;
94 }
95}
96
97// Find the total number of components in a parameter list.
98size_t functionParameterSize( const ast::FunctionType * type ) {
99 size_t sum = 0;
100 for ( auto param : type->params ) {
101 sum += flatTupleSize( param );
102 }
103 return sum;
104}
105
106bool needsPolySpecialization(
107 const ast::Type * formalType,
108 const ast::Type * actualType,
109 const ast::TypeSubstitution * subs ) {
110 if ( !subs ) {
111 return false;
112 }
113
114 using namespace ResolvExpr;
115 ast::OpenVarSet openVars, closedVars;
116 ast::AssertionSet need, have;
117 findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
118 findOpenVars( actualType, openVars, closedVars, need, have, FirstOpen );
119 for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
120 const ast::Type * boundType = subs->lookup( openVar.first );
121 // If the variable is not bound, move onto the next variable.
122 if ( !boundType ) continue;
123
124 // Is the variable cound to another type variable?
125 if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
126 if ( closedVars.find( *inst ) == closedVars.end() ) {
127 return true;
128 }
129 // Otherwise, the variable is bound to a concrete type.
130 } else {
131 return true;
132 }
133 }
134 // None of the type variables are bound.
135 return false;
136}
137
138bool needsTupleSpecialization(
139 const ast::Type * formalType, const ast::Type * actualType ) {
140 // Needs tuple specialization if the structure of the formal type and
141 // actual type do not match.
142
143 // This is the case if the formal type has ttype polymorphism, or if the structure of tuple types
144 // between the function do not match exactly.
145 if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
146 // A pack in the parameter or return type requires specialization.
147 if ( ftype->isTtype() ) {
148 return true;
149 }
150 // Conversion of 0 to a function type does not require specialization.
151 if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
152 return false;
153 }
154 const ast::FunctionType * atype =
155 getFunctionType( actualType->stripReferences() );
156 assertf( atype,
157 "formal type is a function type, but actual type is not: %s",
158 toString( actualType ).c_str() );
159 // Can't tuple specialize if parameter sizes deeply-differ.
160 if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
161 return false;
162 }
163 // If tuple parameter size matches but actual parameter sizes differ
164 // then there needs to be specialization.
165 if ( ftype->params.size() != atype->params.size() ) {
166 return true;
167 }
168 // Total parameter size can be the same, while individual parameters
169 // can have different structure.
170 for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
171 if ( !isTupleStructureMatching(
172 std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
173 return true;
174 }
175 }
176 }
177 return false;
178}
179
180bool needsSpecialization(
181 const ast::Type * formalType, const ast::Type * actualType,
182 const ast::TypeSubstitution * subs ) {
183 return needsPolySpecialization( formalType, actualType, subs )
184 || needsTupleSpecialization( formalType, actualType );
185}
186
187ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
188 const ast::ApplicationExpr * expr ) {
189 assert( expr->func->result );
190 const ast::FunctionType * func = getFunctionType( expr->func->result );
191 assert( func );
192
193 ast::ApplicationExpr * mut = ast::mutate( expr );
194
195 std::vector<ast::ptr<ast::Type>>::const_iterator formal;
196 std::vector<ast::ptr<ast::Expr>>::iterator actual;
197 for ( formal = func->params.begin(), actual = mut->args.begin() ;
198 formal != func->params.end() && actual != mut->args.end() ;
199 ++formal, ++actual ) {
200 *actual = doSpecialization( (*actual)->location,
201 *formal, *actual, getInferredParams( expr ) );
202 }
203 //for ( auto pair : group_iterate( formal->params, mut->args ) ) {
204 // const ast::ptr<ast::Type> & formal = std::get<0>( pair );
205 // ast::ptr<ast::Expr> & actual = std::get<1>( pair );
206 // *actual = doSpecialization( (*actual)->location, *formal, *actual, getInferredParams( expr ) );
207 //}
208 return mut;
209}
210
211// Explode assuming simple cases: either type is pure tuple (but not tuple
212// expr) or type is non-tuple.
213template<typename OutputIterator>
214void explodeSimple( const CodeLocation & location,
215 const ast::Expr * expr, OutputIterator out ) {
216 // Recurse on tuple types using index expressions on each component.
217 if ( auto tuple = expr->result.as<ast::TupleType>() ) {
218 ast::ptr<ast::Expr> cleanup = expr;
219 for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
220 explodeSimple( location,
221 new ast::TupleIndexExpr( location, expr, i ), out );
222 }
223 // For a non-tuple type, output a clone of the expression.
224 } else {
225 *out++ = expr;
226 }
227}
228
229// Restructures the arguments to match the structure of the formal parameters
230// of the actual function. [begin, end) are the exploded arguments.
231template<typename Iterator, typename OutIterator>
232void structureArg( const CodeLocation & location, const ast::Type * type,
233 Iterator & begin, Iterator end, OutIterator out ) {
234 if ( auto tuple = dynamic_cast<const ast::TupleType *>( type ) ) {
235 std::vector<ast::ptr<ast::Expr>> exprs;
236 for ( const ast::Type * t : *tuple ) {
237 structureArg( location, t, begin, end, std::back_inserter( exprs ) );
238 }
239 *out++ = new ast::TupleExpr( location, std::move( exprs ) );
240 } else {
241 assertf( begin != end, "reached the end of the arguments while structuring" );
242 *out++ = *begin++;
243 }
244}
245
#if 0
// Disabled (compiled out) variant of structureArg that returns the
// structured expression instead of writing it to an output iterator.
template<typename Iterator>
const ast::Expr * structureArg(
		const CodeLocation& location, const ast::ptr<ast::Type> & type,
		Iterator & begin, const Iterator & end ) {
	auto tuple = type->as<ast::TupleType>();
	if ( !tuple ) {
		assertf( begin != end, "reached the end of the arguments while structuring" );
		return *begin++;
	}
	std::vector<ast::ptr<ast::Expr>> components;
	for ( const ast::Type * component : *tuple ) {
		components.push_back( structureArg( location, component, begin, end ) );
	}
	return new ast::TupleExpr( location, std::move( components ) );
}
#endif
263
264namespace {
265 struct TypeInstFixer : public ast::WithShortCircuiting {
266 std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
267
268 void previsit(const ast::TypeDecl *) { visit_children = false; }
269 const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
270 if (typeMap.count(typeInst->base)) {
271 ast::TypeInstType * newInst = mutate(typeInst);
272 newInst->expr_id = typeMap[typeInst->base].first;
273 newInst->formal_usage = typeMap[typeInst->base].second;
274 return newInst;
275 }
276 return typeInst;
277 }
278 };
279}
280
281const ast::Expr * SpecializeCore::createThunkFunction(
282 const CodeLocation & location,
283 const ast::FunctionType * funType,
284 const ast::Expr * actual,
285 const ast::InferredParams * inferParams ) {
286 // One set of unique names per program.
287 static UniqueName thunkNamer("_thunk");
288
289 const ast::FunctionType * newType = ast::deepCopy( funType );
290 if ( typeSubs ) {
291 // Must replace only occurrences of type variables
292 // that occure free in the thunk's type.
293 //ast::TypeSubstitution::ApplyResult<ast::FunctionType>
294 // typeSubs->applyFree( newType );
295 auto result = typeSubs->applyFree( newType );
296 newType = result.node.release();
297 }
298
299 using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
300 using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
301
302 //const std::string & thunkName = thunkNamer.newName();
303 //UniqueName paramNamer(thunkName + "Param");
304 UniqueName paramNamer( paramPrefix );
305
306 //auto toParamDecl = [&location, &paramNamer]( const ast::Type * type ) {
307 // return new ast::ObjectDecl(
308 // location, paramNamer.newName(), ast::deepCopy( type ) );
309 //};
310
311 // Create new thunk with same signature as formal type.
312
313 // std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
314 ast::Pass<TypeInstFixer> fixer;
315 for (const auto & kv : newType->forall) {
316 if (fixer.core.typeMap.count(kv->base)) {
317 std::cerr << location << ' ' << kv->base->name << ' ' << kv->expr_id << '_' << kv->formal_usage << ','
318 << fixer.core.typeMap[kv->base].first << '_' << fixer.core.typeMap[kv->base].second << std::endl;
319 assertf(false, "multiple formals in specialize");
320 }
321 else {
322 fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
323 }
324 }
325
326 ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
327 ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
328 location,
329 thunkNamer.newName(),
330 map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
331 return ast::deepCopy( inst->base );
332 } ),
333 map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
334 return ast::deepCopy( expr->var );
335 } ),
336 map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
337 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
338 } ),
339 map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
340 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
341 } ),
342 thunkBody,
343 ast::Storage::Classes(),
344 ast::Linkage::C
345 );
346
347 // thunkFunc->accept(fixer);
348 thunkFunc->fixUniqueId();
349
350
351
352 // Thunks may be generated and not used, avoid them.
353 thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
354
355 // Global thunks must be static to avoid collitions.
356 // Nested thunks must not be unique and hence, not static.
357 thunkFunc->storage.is_static = !isInFunction();
358
359 // Weave thunk parameters into call to actual function,
360 // naming thunk parameters as we go.
361 ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
362
363 const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
364 if ( typeSubs ) {
365 // Need to apply the environment to the actual function's type,
366 // since it may itself be polymorphic.
367 auto result = typeSubs->apply( actualType );
368 actualType = result.node.release();
369 }
370
371 ast::ptr<ast::FunctionType> actualTypeManager = actualType;
372
373 std::vector<ast::ptr<ast::Expr>> args;
374 for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
375 // Name each thunk parameter and explode it.
376 // These are then threaded back into the actual function call.
377 //param->name = paramNamer.newName();
378 ast::DeclWithType * mutParam = ast::mutate( param.get() );
379 // - Should be renamed earlier. -
380 //mutParam->name = paramNamer.newName();
381 explodeSimple( location, new ast::VariableExpr( location, mutParam ),
382 std::back_inserter( args ) );
383 }
384
385 // Walk parameters to the actual function alongside the exploded thunk
386 // parameters and restructure the arguments to match the actual parameters.
387 std::vector<ast::ptr<ast::Expr>>::iterator
388 argBegin = args.begin(), argEnd = args.end();
389 for ( const auto & actualArg : actualType->params ) {
390 structureArg( location, actualArg.get(), argBegin, argEnd,
391 std::back_inserter( app->args ) );
392 }
393 assertf( argBegin == argEnd, "Did not structure all arguments." );
394
395 app->accept(fixer); // this should modify in place
396
397 app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
398 if ( inferParams ) {
399 app->inferred.inferParams() = *inferParams;
400 }
401
402 // Handle any specializations that may still be present.
403 {
404 std::string oldParamPrefix = paramPrefix;
405 paramPrefix += "p";
406 std::list<ast::ptr<ast::Decl>> oldDecls;
407 oldDecls.splice( oldDecls.end(), declsToAddBefore );
408
409 app->accept( *visitor );
410 // Write recursive specializations into the thunk body.
411 for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
412 thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
413 }
414
415 declsToAddBefore = std::move( oldDecls );
416 paramPrefix = std::move( oldParamPrefix );
417 }
418
419 // Add return (or valueless expression) to the thunk.
420 ast::Stmt * appStmt;
421 if ( funType->returns.empty() ) {
422 appStmt = new ast::ExprStmt( app->location, app );
423 } else {
424 appStmt = new ast::ReturnStmt( app->location, app );
425 }
426 thunkBody->push_back( appStmt );
427
428 // Add the thunk definition:
429 declsToAddBefore.push_back( thunkFunc );
430
431 // Return address of thunk function as replacement expression.
432 return new ast::AddressExpr( location,
433 new ast::VariableExpr( location, thunkFunc ) );
434}
435
436const ast::Expr * SpecializeCore::doSpecialization(
437 const CodeLocation & location,
438 const ast::Type * formalType,
439 const ast::Expr * actual,
440 const ast::InferredParams * inferParams ) {
441 assertf( actual->result, "attempting to specialize an untyped expression" );
442 if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
443 if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
444 if ( const ast::ApplicationExpr * expr =
445 dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
446 return createThunkFunction( location, type, expr->func, inferParams );
447 } else if ( auto expr =
448 dynamic_cast<const ast::VariableExpr *>( actual ) ) {
449 return createThunkFunction( location, type, expr, inferParams );
450 } else {
451 // (I don't even know what that comment means.)
452 // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
453 return createThunkFunction( location, type, actual, inferParams );
454 }
455 } else {
456 return actual;
457 }
458 } else {
459 return actual;
460 }
461}
462
463const ast::Expr * SpecializeCore::postvisit(
464 const ast::ApplicationExpr * expr ) {
465 if ( InitTweak::isIntrinsicCallExpr( expr ) ) {
466 return expr;
467 }
468
469 // Create thunks for the inferred parameters.
470 // This is not needed for intrinsic calls, because they aren't
471 // actually passed
472 //
473 // need to handle explicit params before inferred params so that
474 // explicit params do not recieve a changed set of inferParams (and
475 // change them again) alternatively, if order starts to matter then
476 // copy appExpr's inferParams and pass them to handleExplicitParams.
477 ast::ApplicationExpr * mut = handleExplicitParams( expr );
478 if ( !mut->inferred.hasParams() ) {
479 return mut;
480 }
481 ast::InferredParams & inferParams = mut->inferred.inferParams();
482 for ( ast::InferredParams::value_type & inferParam : inferParams ) {
483 inferParam.second.expr = doSpecialization(
484 inferParam.second.expr->location,
485 inferParam.second.formalType,
486 inferParam.second.expr,
487 getInferredParams( inferParam.second.expr )
488 );
489 }
490 return mut;
491}
492
493const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
494 if ( expr->result->isVoid() ) {
495 // No specialization if there is no return value.
496 return expr;
497 }
498 const ast::Expr * specialized = doSpecialization(
499 expr->location, expr->result, expr->arg, getInferredParams( expr ) );
500 if ( specialized != expr->arg ) {
501 // Assume that the specialization incorporates the cast.
502 std::cerr << expr <<std::endl;
503 return specialized;
504 } else {
505 return expr;
506 }
507}
508
509} // namespace
510
511void convertSpecializations( ast::TranslationUnit & translationUnit ) {
512 ast::Pass<SpecializeCore>::run( translationUnit );
513}
514
515} // namespace GenPoly
516
517// Local Variables: //
518// tab-width: 4 //
519// mode: c++ //
520// compile-command: "make install" //
521// End: //
Note: See TracBrowser for help on using the repository browser.