source: src/GenPoly/SpecializeNew.cpp@ 34c6e1e6

Last change on this file since 34c6e1e6 was 8f31be6, checked in by Andrew Beach <ajbeach@…>, 2 years ago

Fixed some warnings, deleted some commented out code.

  • Property mode set to 100644
File size: 16.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecializeNew.cpp -- Generate thunks to specialize polymorphic functions.
8//
9// Author : Andrew Beach
10// Created On : Tue Jun 7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 7 13:37:00 2022
13// Update Count : 0
14//
15
16#include "Specialize.h"
17
18#include "AST/Copy.hpp" // for deepCopy
19#include "AST/Inspect.hpp" // for isIntrinsicCallExpr
20#include "AST/Pass.hpp" // for Pass
21#include "AST/TypeEnvironment.hpp" // for OpenVarSet, AssertionSet
22#include "Common/UniqueName.h" // for UniqueName
23#include "GenPoly/GenPoly.h" // for getFunctionType
24#include "ResolvExpr/FindOpenVars.h" // for findOpenVars
25#include "ResolvExpr/TypeEnvironment.h" // for FirstOpen, FirstClosed
26
27namespace GenPoly {
28
29namespace {
30
31struct SpecializeCore final :
32 public ast::WithConstTypeSubstitution,
33 public ast::WithDeclsToAdd<>,
34 public ast::WithVisitorRef<SpecializeCore> {
35 std::string paramPrefix = "_p";
36
37 ast::ApplicationExpr * handleExplicitParams(
38 const ast::ApplicationExpr * expr );
39 const ast::Expr * createThunkFunction(
40 const CodeLocation & location,
41 const ast::FunctionType * funType,
42 const ast::Expr * actual,
43 const ast::InferredParams * inferParams );
44 const ast::Expr * doSpecialization(
45 const CodeLocation & location,
46 const ast::Type * formalType,
47 const ast::Expr * actual,
48 const ast::InferredParams * inferParams );
49
50 const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
51 const ast::Expr * postvisit( const ast::CastExpr * expr );
52};
53
54const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
55 const ast::Expr::InferUnion & inferred = expr->inferred;
56 if ( inferred.hasParams() ) {
57 return &inferred.inferParams();
58 } else {
59 return nullptr;
60 }
61}
62
63// Check if both types have the same structure. The leaf (non-tuple) types
64// don't have to match but the tuples must match.
65bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
66 const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
67 const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
68 if ( tt0 && tt1 ) {
69 if ( tt0->size() != tt1->size() ) {
70 return false;
71 }
72 for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
73 if ( !isTupleStructureMatching(
74 std::get<0>( types ), std::get<1>( types ) ) ) {
75 return false;
76 }
77 }
78 return true;
79 }
80 return (!tt0 && !tt1);
81}
82
83// The number of elements in a type if it is a flattened tuple.
84size_t flatTupleSize( const ast::Type * type ) {
85 if ( auto tuple = dynamic_cast<const ast::TupleType *>( type ) ) {
86 size_t sum = 0;
87 for ( auto t : *tuple ) {
88 sum += flatTupleSize( t );
89 }
90 return sum;
91 } else {
92 return 1;
93 }
94}
95
96// Find the total number of components in a parameter list.
97size_t functionParameterSize( const ast::FunctionType * type ) {
98 size_t sum = 0;
99 for ( auto param : type->params ) {
100 sum += flatTupleSize( param );
101 }
102 return sum;
103}
104
105bool needsPolySpecialization(
106 const ast::Type * /*formalType*/,
107 const ast::Type * actualType,
108 const ast::TypeSubstitution * subs ) {
109 if ( !subs ) {
110 return false;
111 }
112
113 using namespace ResolvExpr;
114 ast::OpenVarSet openVars, closedVars;
115 ast::AssertionSet need, have; // unused
116 ast::TypeEnvironment env; // unused
117 // findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
118 findOpenVars( actualType, openVars, closedVars, need, have, env, FirstOpen );
119 for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
120 const ast::Type * boundType = subs->lookup( openVar.first );
121 // If the variable is not bound, move onto the next variable.
122 if ( !boundType ) continue;
123
124 // Is the variable cound to another type variable?
125 if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
126 if ( closedVars.find( *inst ) == closedVars.end() ) {
127 return true;
128 } else {
129 assertf(false, "closed: %s", inst->name.c_str());
130 }
131 // Otherwise, the variable is bound to a concrete type.
132 } else {
133 return true;
134 }
135 }
136 // None of the type variables are bound.
137 return false;
138}
139
140bool needsTupleSpecialization(
141 const ast::Type * formalType, const ast::Type * actualType ) {
142 // Needs tuple specialization if the structure of the formal type and
143 // actual type do not match.
144
145 // This is the case if the formal type has ttype polymorphism, or if the structure of tuple types
146 // between the function do not match exactly.
147 if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
148 // A pack in the parameter or return type requires specialization.
149 if ( ftype->isTtype() ) {
150 return true;
151 }
152 // Conversion of 0 to a function type does not require specialization.
153 if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
154 return false;
155 }
156 const ast::FunctionType * atype =
157 getFunctionType( actualType->stripReferences() );
158 assertf( atype,
159 "formal type is a function type, but actual type is not: %s",
160 toString( actualType ).c_str() );
161 // Can't tuple specialize if parameter sizes deeply-differ.
162 if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
163 return false;
164 }
165 // If tuple parameter size matches but actual parameter sizes differ
166 // then there needs to be specialization.
167 if ( ftype->params.size() != atype->params.size() ) {
168 return true;
169 }
170 // Total parameter size can be the same, while individual parameters
171 // can have different structure.
172 for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
173 if ( !isTupleStructureMatching(
174 std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
175 return true;
176 }
177 }
178 }
179 return false;
180}
181
182bool needsSpecialization(
183 const ast::Type * formalType, const ast::Type * actualType,
184 const ast::TypeSubstitution * subs ) {
185 return needsPolySpecialization( formalType, actualType, subs )
186 || needsTupleSpecialization( formalType, actualType );
187}
188
189ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
190 const ast::ApplicationExpr * expr ) {
191 assert( expr->func->result );
192 const ast::FunctionType * func = getFunctionType( expr->func->result );
193 assert( func );
194
195 ast::ApplicationExpr * mut = ast::mutate( expr );
196
197 std::vector<ast::ptr<ast::Type>>::const_iterator formal;
198 std::vector<ast::ptr<ast::Expr>>::iterator actual;
199 for ( formal = func->params.begin(), actual = mut->args.begin() ;
200 formal != func->params.end() && actual != mut->args.end() ;
201 ++formal, ++actual ) {
202 *actual = doSpecialization( (*actual)->location,
203 *formal, *actual, getInferredParams( expr ) );
204 }
205 return mut;
206}
207
208// Explode assuming simple cases: either type is pure tuple (but not tuple
209// expr) or type is non-tuple.
210template<typename OutputIterator>
211void explodeSimple( const CodeLocation & location,
212 const ast::Expr * expr, OutputIterator out ) {
213 // Recurse on tuple types using index expressions on each component.
214 if ( auto tuple = expr->result.as<ast::TupleType>() ) {
215 ast::ptr<ast::Expr> cleanup = expr;
216 for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
217 explodeSimple( location,
218 new ast::TupleIndexExpr( location, expr, i ), out );
219 }
220 // For a non-tuple type, output a clone of the expression.
221 } else {
222 *out++ = expr;
223 }
224}
225
226// Restructures arguments to match the structure of the formal parameters
227// of the actual function. Returns the next structured argument.
228template<typename Iterator>
229const ast::Expr * structureArg(
230 const CodeLocation& location, const ast::ptr<ast::Type> & type,
231 Iterator & begin, const Iterator & end ) {
232 if ( auto tuple = type.as<ast::TupleType>() ) {
233 std::vector<ast::ptr<ast::Expr>> exprs;
234 for ( const ast::ptr<ast::Type> & t : *tuple ) {
235 exprs.push_back( structureArg( location, t, begin, end ) );
236 }
237 return new ast::TupleExpr( location, std::move( exprs ) );
238 } else {
239 assertf( begin != end, "reached the end of the arguments while structuring" );
240 return *begin++;
241 }
242}
243
244struct TypeInstFixer final : public ast::WithShortCircuiting {
245 std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
246
247 void previsit(const ast::TypeDecl *) { visit_children = false; }
248 const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
249 if (typeMap.count(typeInst->base)) {
250 ast::TypeInstType * newInst = mutate(typeInst);
251 auto const & pair = typeMap[typeInst->base];
252 newInst->expr_id = pair.first;
253 newInst->formal_usage = pair.second;
254 return newInst;
255 }
256 return typeInst;
257 }
258};
259
260const ast::Expr * SpecializeCore::createThunkFunction(
261 const CodeLocation & location,
262 const ast::FunctionType * funType,
263 const ast::Expr * actual,
264 const ast::InferredParams * inferParams ) {
265 // One set of unique names per program.
266 static UniqueName thunkNamer("_thunk");
267
268 const ast::FunctionType * newType = ast::deepCopy( funType );
269 if ( typeSubs ) {
270 // Must replace only occurrences of type variables
271 // that occure free in the thunk's type.
272 auto result = typeSubs->applyFree( newType );
273 newType = result.node.release();
274 }
275
276 using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
277 using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
278
279 UniqueName paramNamer( paramPrefix );
280
281 // Create new thunk with same signature as formal type.
282 ast::Pass<TypeInstFixer> fixer;
283 for (const auto & kv : newType->forall) {
284 if (fixer.core.typeMap.count(kv->base)) {
285 std::cerr << location << ' ' << kv->base->name
286 << ' ' << kv->expr_id << '_' << kv->formal_usage
287 << ',' << fixer.core.typeMap[kv->base].first
288 << '_' << fixer.core.typeMap[kv->base].second << std::endl;
289 assertf(false, "multiple formals in specialize");
290 }
291 else {
292 fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
293 }
294 }
295
296 ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
297 ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
298 location,
299 thunkNamer.newName(),
300 map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
301 return ast::deepCopy( inst->base );
302 } ),
303 map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
304 return ast::deepCopy( expr->var );
305 } ),
306 map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
307 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
308 } ),
309 map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
310 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
311 } ),
312 thunkBody,
313 ast::Storage::Classes(),
314 ast::Linkage::C
315 );
316
317 thunkFunc->fixUniqueId();
318
319 // Thunks may be generated and not used, avoid them.
320 thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
321
322 // Global thunks must be static to avoid collitions.
323 // Nested thunks must not be unique and hence, not static.
324 thunkFunc->storage.is_static = !isInFunction();
325
326 // Weave thunk parameters into call to actual function,
327 // naming thunk parameters as we go.
328 ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
329
330 const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
331 if ( typeSubs ) {
332 // Need to apply the environment to the actual function's type,
333 // since it may itself be polymorphic.
334 auto result = typeSubs->apply( actualType );
335 actualType = result.node.release();
336 }
337
338 ast::ptr<ast::FunctionType> actualTypeManager = actualType;
339
340 std::vector<ast::ptr<ast::Expr>> args;
341 for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
342 // Name each thunk parameter and explode it.
343 // These are then threaded back into the actual function call.
344 ast::DeclWithType * mutParam = ast::mutate( param.get() );
345 explodeSimple( location, new ast::VariableExpr( location, mutParam ),
346 std::back_inserter( args ) );
347 }
348
349 // Walk parameters to the actual function alongside the exploded thunk
350 // parameters and restructure the arguments to match the actual parameters.
351 std::vector<ast::ptr<ast::Expr>>::iterator
352 argBegin = args.begin(), argEnd = args.end();
353 for ( const auto & actualArg : actualType->params ) {
354 app->args.push_back(
355 structureArg( location, actualArg.get(), argBegin, argEnd ) );
356 }
357 assertf( argBegin == argEnd, "Did not structure all arguments." );
358
359 app->accept(fixer); // this should modify in place
360
361 app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
362 if ( inferParams ) {
363 app->inferred.inferParams() = *inferParams;
364 }
365
366 // Handle any specializations that may still be present.
367 {
368 std::string oldParamPrefix = paramPrefix;
369 paramPrefix += "p";
370 std::list<ast::ptr<ast::Decl>> oldDecls;
371 oldDecls.splice( oldDecls.end(), declsToAddBefore );
372
373 app->accept( *visitor );
374 // Write recursive specializations into the thunk body.
375 for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
376 thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
377 }
378
379 declsToAddBefore = std::move( oldDecls );
380 paramPrefix = std::move( oldParamPrefix );
381 }
382
383 // Add return (or valueless expression) to the thunk.
384 ast::Stmt * appStmt;
385 if ( funType->returns.empty() ) {
386 appStmt = new ast::ExprStmt( app->location, app );
387 } else {
388 appStmt = new ast::ReturnStmt( app->location, app );
389 }
390 thunkBody->push_back( appStmt );
391
392 // Add the thunk definition:
393 declsToAddBefore.push_back( thunkFunc );
394
395 // Return address of thunk function as replacement expression.
396 return new ast::AddressExpr( location,
397 new ast::VariableExpr( location, thunkFunc ) );
398}
399
400const ast::Expr * SpecializeCore::doSpecialization(
401 const CodeLocation & location,
402 const ast::Type * formalType,
403 const ast::Expr * actual,
404 const ast::InferredParams * inferParams ) {
405 assertf( actual->result, "attempting to specialize an untyped expression" );
406 if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
407 if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
408 if ( const ast::ApplicationExpr * expr =
409 dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
410 return createThunkFunction( location, type, expr->func, inferParams );
411 } else if ( auto expr =
412 dynamic_cast<const ast::VariableExpr *>( actual ) ) {
413 return createThunkFunction( location, type, expr, inferParams );
414 } else {
415 // (I don't even know what that comment means.)
416 // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
417 return createThunkFunction( location, type, actual, inferParams );
418 }
419 } else {
420 return actual;
421 }
422 } else {
423 return actual;
424 }
425}
426
427const ast::Expr * SpecializeCore::postvisit(
428 const ast::ApplicationExpr * expr ) {
429 if ( ast::isIntrinsicCallExpr( expr ) ) {
430 return expr;
431 }
432
433 // Create thunks for the inferred parameters.
434 // This is not needed for intrinsic calls, because they aren't
435 // actually passed to the function. It needs to handle explicit params
436 // before inferred params so that explicit params do not recieve a
437 // changed set of inferParams (and change them again).
438 // Alternatively, if order starts to matter then copy expr's inferParams
439 // and pass them to handleExplicitParams.
440 ast::ApplicationExpr * mut = handleExplicitParams( expr );
441 if ( !mut->inferred.hasParams() ) {
442 return mut;
443 }
444 ast::InferredParams & inferParams = mut->inferred.inferParams();
445 for ( ast::InferredParams::value_type & inferParam : inferParams ) {
446 inferParam.second.expr = doSpecialization(
447 inferParam.second.expr->location,
448 inferParam.second.formalType,
449 inferParam.second.expr,
450 getInferredParams( inferParam.second.expr )
451 );
452 }
453 return mut;
454}
455
456const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
457 if ( expr->result->isVoid() ) {
458 // No specialization if there is no return value.
459 return expr;
460 }
461 const ast::Expr * specialized = doSpecialization(
462 expr->location, expr->result, expr->arg, getInferredParams( expr ) );
463 if ( specialized != expr->arg ) {
464 // Assume that the specialization incorporates the cast.
465 return specialized;
466 } else {
467 return expr;
468 }
469}
470
471} // namespace
472
473void convertSpecializations( ast::TranslationUnit & translationUnit ) {
474 ast::Pass<SpecializeCore>::run( translationUnit );
475}
476
477} // namespace GenPoly
478
479// Local Variables: //
480// tab-width: 4 //
481// mode: c++ //
482// compile-command: "make install" //
483// End: //
Note: See TracBrowser for help on using the repository browser.