source: src/GenPoly/SpecializeNew.cpp@ 8d182b1

Last change on this file since 8d182b1 was c6b4432, checked in by Andrew Beach <ajbeach@…>, 23 months ago

Remove BaseSyntaxNode and clean-up.

  • Property mode set to 100644
File size: 16.1 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecializeNew.cpp -- Generate thunks to specialize polymorphic functions.
8//
9// Author : Andrew Beach
10// Created On : Tue Jun 7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 7 13:37:00 2022
13// Update Count : 0
14//
15
16#include "Specialize.h"
17
18#include "AST/Copy.hpp" // for deepCopy
19#include "AST/Inspect.hpp" // for isIntrinsicCallExpr
20#include "AST/Pass.hpp" // for Pass
21#include "AST/TypeEnvironment.hpp" // for OpenVarSet, AssertionSet
22#include "Common/UniqueName.h" // for UniqueName
23#include "GenPoly/GenPoly.h" // for getFunctionType
24#include "ResolvExpr/FindOpenVars.h" // for findOpenVars
25
26namespace GenPoly {
27
28namespace {
29
30struct SpecializeCore final :
31 public ast::WithConstTypeSubstitution,
32 public ast::WithDeclsToAdd<>,
33 public ast::WithVisitorRef<SpecializeCore> {
34 std::string paramPrefix = "_p";
35
36 ast::ApplicationExpr * handleExplicitParams(
37 const ast::ApplicationExpr * expr );
38 const ast::Expr * createThunkFunction(
39 const CodeLocation & location,
40 const ast::FunctionType * funType,
41 const ast::Expr * actual,
42 const ast::InferredParams * inferParams );
43 const ast::Expr * doSpecialization(
44 const CodeLocation & location,
45 const ast::Type * formalType,
46 const ast::Expr * actual,
47 const ast::InferredParams * inferParams );
48
49 const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
50 const ast::Expr * postvisit( const ast::CastExpr * expr );
51};
52
53const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
54 const ast::Expr::InferUnion & inferred = expr->inferred;
55 if ( inferred.hasParams() ) {
56 return &inferred.inferParams();
57 } else {
58 return nullptr;
59 }
60}
61
62// Check if both types have the same structure. The leaf (non-tuple) types
63// don't have to match but the tuples must match.
64bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
65 const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
66 const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
67 if ( tt0 && tt1 ) {
68 if ( tt0->size() != tt1->size() ) {
69 return false;
70 }
71 for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
72 if ( !isTupleStructureMatching(
73 std::get<0>( types ), std::get<1>( types ) ) ) {
74 return false;
75 }
76 }
77 return true;
78 }
79 return (!tt0 && !tt1);
80}
81
82// The number of elements in a list, if all tuples had been flattened.
83size_t flatTypeListSize( const std::vector<ast::ptr<ast::Type>> & types ) {
84 size_t sum = 0;
85 for ( const ast::ptr<ast::Type> & type : types ) {
86 if ( const ast::TupleType * tuple = type.as<ast::TupleType>() ) {
87 sum += flatTypeListSize( tuple->types );
88 } else {
89 sum += 1;
90 }
91 }
92 return sum;
93}
94
95// Find the total number of components in a parameter list.
96size_t functionParameterSize( const ast::FunctionType * type ) {
97 return flatTypeListSize( type->params );
98}
99
100bool needsPolySpecialization(
101 const ast::Type * /*formalType*/,
102 const ast::Type * actualType,
103 const ast::TypeSubstitution * subs ) {
104 if ( !subs ) {
105 return false;
106 }
107
108 using namespace ResolvExpr;
109 ast::OpenVarSet openVars, closedVars;
110 ast::AssertionSet need, have; // unused
111 ast::TypeEnvironment env; // unused
112 // findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
113 findOpenVars( actualType, openVars, closedVars, need, have, env, FirstOpen );
114 for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
115 const ast::Type * boundType = subs->lookup( openVar.first );
116 // If the variable is not bound, move onto the next variable.
117 if ( !boundType ) continue;
118
119 // Is the variable cound to another type variable?
120 if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
121 if ( closedVars.find( *inst ) == closedVars.end() ) {
122 return true;
123 } else {
124 assertf(false, "closed: %s", inst->name.c_str());
125 }
126 // Otherwise, the variable is bound to a concrete type.
127 } else {
128 return true;
129 }
130 }
131 // None of the type variables are bound.
132 return false;
133}
134
135bool needsTupleSpecialization(
136 const ast::Type * formalType, const ast::Type * actualType ) {
137 // Needs tuple specialization if the structure of the formal type and
138 // actual type do not match.
139
140 // This is the case if the formal type has ttype polymorphism, or if the structure of tuple types
141 // between the function do not match exactly.
142 if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
143 // A pack in the parameter or return type requires specialization.
144 if ( ftype->isTtype() ) {
145 return true;
146 }
147 // Conversion of 0 to a function type does not require specialization.
148 if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
149 return false;
150 }
151 const ast::FunctionType * atype =
152 getFunctionType( actualType->stripReferences() );
153 assertf( atype,
154 "formal type is a function type, but actual type is not: %s",
155 toString( actualType ).c_str() );
156 // Can't tuple specialize if parameter sizes deeply-differ.
157 if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
158 return false;
159 }
160 // If tuple parameter size matches but actual parameter sizes differ
161 // then there needs to be specialization.
162 if ( ftype->params.size() != atype->params.size() ) {
163 return true;
164 }
165 // Total parameter size can be the same, while individual parameters
166 // can have different structure.
167 for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
168 if ( !isTupleStructureMatching(
169 std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
170 return true;
171 }
172 }
173 }
174 return false;
175}
176
177bool needsSpecialization(
178 const ast::Type * formalType, const ast::Type * actualType,
179 const ast::TypeSubstitution * subs ) {
180 return needsPolySpecialization( formalType, actualType, subs )
181 || needsTupleSpecialization( formalType, actualType );
182}
183
184ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
185 const ast::ApplicationExpr * expr ) {
186 assert( expr->func->result );
187 const ast::FunctionType * func = getFunctionType( expr->func->result );
188 assert( func );
189
190 ast::ApplicationExpr * mut = ast::mutate( expr );
191
192 std::vector<ast::ptr<ast::Type>>::const_iterator formal;
193 std::vector<ast::ptr<ast::Expr>>::iterator actual;
194 for ( formal = func->params.begin(), actual = mut->args.begin() ;
195 formal != func->params.end() && actual != mut->args.end() ;
196 ++formal, ++actual ) {
197 *actual = doSpecialization( (*actual)->location,
198 *formal, *actual, getInferredParams( expr ) );
199 }
200 return mut;
201}
202
203// Explode assuming simple cases: either type is pure tuple (but not tuple
204// expr) or type is non-tuple.
205template<typename OutputIterator>
206void explodeSimple( const CodeLocation & location,
207 const ast::Expr * expr, OutputIterator out ) {
208 // Recurse on tuple types using index expressions on each component.
209 if ( auto tuple = expr->result.as<ast::TupleType>() ) {
210 ast::ptr<ast::Expr> cleanup = expr;
211 for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
212 explodeSimple( location,
213 new ast::TupleIndexExpr( location, expr, i ), out );
214 }
215 // For a non-tuple type, output a clone of the expression.
216 } else {
217 *out++ = expr;
218 }
219}
220
221// Restructures arguments to match the structure of the formal parameters
222// of the actual function. Returns the next structured argument.
223template<typename Iterator>
224const ast::Expr * structureArg(
225 const CodeLocation& location, const ast::ptr<ast::Type> & type,
226 Iterator & begin, const Iterator & end ) {
227 if ( auto tuple = type.as<ast::TupleType>() ) {
228 std::vector<ast::ptr<ast::Expr>> exprs;
229 for ( const ast::ptr<ast::Type> & t : *tuple ) {
230 exprs.push_back( structureArg( location, t, begin, end ) );
231 }
232 return new ast::TupleExpr( location, std::move( exprs ) );
233 } else {
234 assertf( begin != end, "reached the end of the arguments while structuring" );
235 return *begin++;
236 }
237}
238
239struct TypeInstFixer final : public ast::WithShortCircuiting {
240 std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
241
242 void previsit(const ast::TypeDecl *) { visit_children = false; }
243 const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
244 if (typeMap.count(typeInst->base)) {
245 ast::TypeInstType * newInst = mutate(typeInst);
246 auto const & pair = typeMap[typeInst->base];
247 newInst->expr_id = pair.first;
248 newInst->formal_usage = pair.second;
249 return newInst;
250 }
251 return typeInst;
252 }
253};
254
255const ast::Expr * SpecializeCore::createThunkFunction(
256 const CodeLocation & location,
257 const ast::FunctionType * funType,
258 const ast::Expr * actual,
259 const ast::InferredParams * inferParams ) {
260 // One set of unique names per program.
261 static UniqueName thunkNamer("_thunk");
262
263 const ast::FunctionType * newType = ast::deepCopy( funType );
264 if ( typeSubs ) {
265 // Must replace only occurrences of type variables
266 // that occure free in the thunk's type.
267 auto result = typeSubs->applyFree( newType );
268 newType = result.node.release();
269 }
270
271 using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
272 using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
273
274 UniqueName paramNamer( paramPrefix );
275
276 // Create new thunk with same signature as formal type.
277 ast::Pass<TypeInstFixer> fixer;
278 for (const auto & kv : newType->forall) {
279 if (fixer.core.typeMap.count(kv->base)) {
280 std::cerr << location << ' ' << kv->base->name
281 << ' ' << kv->expr_id << '_' << kv->formal_usage
282 << ',' << fixer.core.typeMap[kv->base].first
283 << '_' << fixer.core.typeMap[kv->base].second << std::endl;
284 assertf(false, "multiple formals in specialize");
285 }
286 else {
287 fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
288 }
289 }
290
291 ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
292 ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
293 location,
294 thunkNamer.newName(),
295 map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
296 return ast::deepCopy( inst->base );
297 } ),
298 map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
299 return ast::deepCopy( expr->var );
300 } ),
301 map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
302 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
303 } ),
304 map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
305 return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
306 } ),
307 thunkBody,
308 ast::Storage::Classes(),
309 ast::Linkage::C
310 );
311
312 thunkFunc->fixUniqueId();
313
314 // Thunks may be generated and not used, avoid them.
315 thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
316
317 // Global thunks must be static to avoid collitions.
318 // Nested thunks must not be unique and hence, not static.
319 thunkFunc->storage.is_static = !isInFunction();
320
321 // Weave thunk parameters into call to actual function,
322 // naming thunk parameters as we go.
323 ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
324
325 const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
326 if ( typeSubs ) {
327 // Need to apply the environment to the actual function's type,
328 // since it may itself be polymorphic.
329 auto result = typeSubs->apply( actualType );
330 actualType = result.node.release();
331 }
332
333 ast::ptr<ast::FunctionType> actualTypeManager = actualType;
334
335 std::vector<ast::ptr<ast::Expr>> args;
336 for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
337 // Name each thunk parameter and explode it.
338 // These are then threaded back into the actual function call.
339 ast::DeclWithType * mutParam = ast::mutate( param.get() );
340 explodeSimple( location, new ast::VariableExpr( location, mutParam ),
341 std::back_inserter( args ) );
342 }
343
344 // Walk parameters to the actual function alongside the exploded thunk
345 // parameters and restructure the arguments to match the actual parameters.
346 std::vector<ast::ptr<ast::Expr>>::iterator
347 argBegin = args.begin(), argEnd = args.end();
348 for ( const auto & actualArg : actualType->params ) {
349 app->args.push_back(
350 structureArg( location, actualArg.get(), argBegin, argEnd ) );
351 }
352 assertf( argBegin == argEnd, "Did not structure all arguments." );
353
354 app->accept(fixer); // this should modify in place
355
356 app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
357 if ( inferParams ) {
358 app->inferred.inferParams() = *inferParams;
359 }
360
361 // Handle any specializations that may still be present.
362 {
363 std::string oldParamPrefix = paramPrefix;
364 paramPrefix += "p";
365 std::list<ast::ptr<ast::Decl>> oldDecls;
366 oldDecls.splice( oldDecls.end(), declsToAddBefore );
367
368 app->accept( *visitor );
369 // Write recursive specializations into the thunk body.
370 for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
371 thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
372 }
373
374 declsToAddBefore = std::move( oldDecls );
375 paramPrefix = std::move( oldParamPrefix );
376 }
377
378 // Add return (or valueless expression) to the thunk.
379 ast::Stmt * appStmt;
380 if ( funType->returns.empty() ) {
381 appStmt = new ast::ExprStmt( app->location, app );
382 } else {
383 appStmt = new ast::ReturnStmt( app->location, app );
384 }
385 thunkBody->push_back( appStmt );
386
387 // Add the thunk definition:
388 declsToAddBefore.push_back( thunkFunc );
389
390 // Return address of thunk function as replacement expression.
391 return new ast::AddressExpr( location,
392 new ast::VariableExpr( location, thunkFunc ) );
393}
394
395const ast::Expr * SpecializeCore::doSpecialization(
396 const CodeLocation & location,
397 const ast::Type * formalType,
398 const ast::Expr * actual,
399 const ast::InferredParams * inferParams ) {
400 assertf( actual->result, "attempting to specialize an untyped expression" );
401 if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
402 if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
403 if ( const ast::ApplicationExpr * expr =
404 dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
405 return createThunkFunction( location, type, expr->func, inferParams );
406 } else if ( auto expr =
407 dynamic_cast<const ast::VariableExpr *>( actual ) ) {
408 return createThunkFunction( location, type, expr, inferParams );
409 } else {
410 // (I don't even know what that comment means.)
411 // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
412 return createThunkFunction( location, type, actual, inferParams );
413 }
414 } else {
415 return actual;
416 }
417 } else {
418 return actual;
419 }
420}
421
422const ast::Expr * SpecializeCore::postvisit(
423 const ast::ApplicationExpr * expr ) {
424 if ( ast::isIntrinsicCallExpr( expr ) ) {
425 return expr;
426 }
427
428 // Create thunks for the inferred parameters.
429 // This is not needed for intrinsic calls, because they aren't
430 // actually passed to the function. It needs to handle explicit params
431 // before inferred params so that explicit params do not recieve a
432 // changed set of inferParams (and change them again).
433 // Alternatively, if order starts to matter then copy expr's inferParams
434 // and pass them to handleExplicitParams.
435 ast::ApplicationExpr * mut = handleExplicitParams( expr );
436 if ( !mut->inferred.hasParams() ) {
437 return mut;
438 }
439 ast::InferredParams & inferParams = mut->inferred.inferParams();
440 for ( ast::InferredParams::value_type & inferParam : inferParams ) {
441 inferParam.second.expr = doSpecialization(
442 inferParam.second.expr->location,
443 inferParam.second.formalType,
444 inferParam.second.expr,
445 getInferredParams( inferParam.second.expr )
446 );
447 }
448 return mut;
449}
450
451const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
452 if ( expr->result->isVoid() ) {
453 // No specialization if there is no return value.
454 return expr;
455 }
456 const ast::Expr * specialized = doSpecialization(
457 expr->location, expr->result, expr->arg, getInferredParams( expr ) );
458 if ( specialized != expr->arg ) {
459 // Assume that the specialization incorporates the cast.
460 return specialized;
461 } else {
462 return expr;
463 }
464}
465
466} // namespace
467
468void convertSpecializations( ast::TranslationUnit & translationUnit ) {
469 ast::Pass<SpecializeCore>::run( translationUnit );
470}
471
472} // namespace GenPoly
473
474// Local Variables: //
475// tab-width: 4 //
476// mode: c++ //
477// compile-command: "make install" //
478// End: //
Note: See TracBrowser for help on using the repository browser.