source: src/GenPoly/Specialize.cpp @ d84f2ae

Last change on this file since d84f2ae was 0bf03ba2, checked in by Michael Brooks <mlbrooks@…>, 4 weeks ago

Remove warnings due to unused parameters in generated code for zero-length ttype instantiations.

  • Property mode set to 100644
File size: 16.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Specialize.cpp -- Generate thunks to specialize polymorphic functions.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Jun  7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun  7 13:37:00 2022
13// Update Count     : 0
14//
15
16#include "Specialize.hpp"
17
18#include "AST/Copy.hpp"                  // for deepCopy
19#include "AST/Inspect.hpp"               // for isIntrinsicCallExpr
20#include "AST/Pass.hpp"                  // for Pass
21#include "AST/TypeEnvironment.hpp"       // for OpenVarSet, AssertionSet
22#include "Common/UniqueName.hpp"         // for UniqueName
23#include "GenPoly/GenPoly.hpp"           // for getFunctionType
24#include "ResolvExpr/FindOpenVars.hpp"   // for findOpenVars
25
26namespace GenPoly {
27
28namespace {
29
30struct SpecializeCore final :
31                public ast::WithConstTypeSubstitution,
32                public ast::WithDeclsToAdd,
33                public ast::WithVisitorRef<SpecializeCore> {
34        std::string paramPrefix = "_p";
35
36        ast::ApplicationExpr * handleExplicitParams(
37                const ast::ApplicationExpr * expr );
38        const ast::Expr * createThunkFunction(
39                const CodeLocation & location,
40                const ast::FunctionType * funType,
41                const ast::Expr * actual,
42                const ast::InferredParams * inferParams );
43        const ast::Expr * doSpecialization(
44                const CodeLocation & location,
45                const ast::Type * formalType,
46                const ast::Expr * actual,
47                const ast::InferredParams * inferParams );
48
49        const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
50        const ast::Expr * postvisit( const ast::CastExpr * expr );
51};
52
53const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
54        const ast::Expr::InferUnion & inferred = expr->inferred;
55        if ( inferred.hasParams() ) {
56                return &inferred.inferParams();
57        } else {
58                return nullptr;
59        }
60}
61
62// Check if both types have the same structure. The leaf (non-tuple) types
63// don't have to match but the tuples must match.
64bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
65        const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
66        const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
67        if ( tt0 && tt1 ) {
68                if ( tt0->size() != tt1->size() ) {
69                        return false;
70                }
71                for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
72                        if ( !isTupleStructureMatching(
73                                        std::get<0>( types ), std::get<1>( types ) ) ) {
74                                return false;
75                        }
76                }
77                return true;
78        }
79        return (!tt0 && !tt1);
80}
81
82// The number of elements in a list, if all tuples had been flattened.
83size_t flatTypeListSize( const std::vector<ast::ptr<ast::Type>> & types ) {
84        size_t sum = 0;
85        for ( const ast::ptr<ast::Type> & type : types ) {
86                if ( const ast::TupleType * tuple = type.as<ast::TupleType>() ) {
87                        sum += flatTypeListSize( tuple->types );
88                } else {
89                        sum += 1;
90                }
91        }
92        return sum;
93}
94
95// Find the total number of components in a parameter list.
96size_t functionParameterSize( const ast::FunctionType * type ) {
97        return flatTypeListSize( type->params );
98}
99
100bool needsPolySpecialization(
101                const ast::Type * /*formalType*/,
102                const ast::Type * actualType,
103                const ast::TypeSubstitution * subs ) {
104        if ( !subs ) {
105                return false;
106        }
107
108        using namespace ResolvExpr;
109        ast::OpenVarSet openVars, closedVars;
110        ast::AssertionSet need, have; // unused
111        ast::TypeEnvironment env; // unused
112        // findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
113        findOpenVars( actualType, openVars, closedVars, need, have, env, FirstOpen );
114        for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
115                const ast::Type * boundType = subs->lookup( openVar.first );
116                // If the variable is not bound, move onto the next variable.
117                if ( !boundType ) continue;
118
119                // Is the variable cound to another type variable?
120                if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
121                        if ( closedVars.find( *inst ) == closedVars.end() ) {
122                                return true;
123                        } else {
124                                assertf(false, "closed: %s", inst->name.c_str());
125                        }
126                // Otherwise, the variable is bound to a concrete type.
127                } else {
128                        return true;
129                }
130        }
131        // None of the type variables are bound.
132        return false;
133}
134
135bool needsTupleSpecialization(
136                const ast::Type * formalType, const ast::Type * actualType ) {
137        // Needs tuple specialization if the structure of the formal type and
138        // actual type do not match.
139
140        // This is the case if the formal type has ttype polymorphism, or if the structure  of tuple types
141        // between the function do not match exactly.
142        if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
143                // A pack in the parameter or return type requires specialization.
144                if ( ftype->isTtype() ) {
145                        return true;
146                }
147                // Conversion of 0 to a function type does not require specialization.
148                if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
149                        return false;
150                }
151                const ast::FunctionType * atype =
152                        getFunctionType( actualType->stripReferences() );
153                assertf( atype,
154                        "formal type is a function type, but actual type is not: %s",
155                        toString( actualType ).c_str() );
156                // Can't tuple specialize if parameter sizes deeply-differ.
157                if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
158                        return false;
159                }
160                // If tuple parameter size matches but actual parameter sizes differ
161                // then there needs to be specialization.
162                if ( ftype->params.size() != atype->params.size() ) {
163                        return true;
164                }
165                // Total parameter size can be the same, while individual parameters
166                // can have different structure.
167                for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
168                        if ( !isTupleStructureMatching(
169                                        std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
170                                return true;
171                        }
172                }
173        }
174        return false;
175}
176
177bool needsSpecialization(
178                const ast::Type * formalType, const ast::Type * actualType,
179                const ast::TypeSubstitution * subs ) {
180        return needsPolySpecialization( formalType, actualType, subs )
181                || needsTupleSpecialization( formalType, actualType );
182}
183
184ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
185                const ast::ApplicationExpr * expr ) {
186        assert( expr->func->result );
187        const ast::FunctionType * func = getFunctionType( expr->func->result );
188        assert( func );
189
190        ast::ApplicationExpr * mut = ast::mutate( expr );
191
192        std::vector<ast::ptr<ast::Type>>::const_iterator formal;
193        std::vector<ast::ptr<ast::Expr>>::iterator actual;
194        for ( formal = func->params.begin(), actual = mut->args.begin() ;
195                        formal != func->params.end() && actual != mut->args.end() ;
196                        ++formal, ++actual ) {
197                *actual = doSpecialization( (*actual)->location,
198                        *formal, *actual, getInferredParams( expr ) );
199        }
200        return mut;
201}
202
203// Explode assuming simple cases: either type is pure tuple (but not tuple
204// expr) or type is non-tuple.
205template<typename OutputIterator>
206void explodeSimple( const CodeLocation & location,
207                const ast::Expr * expr, OutputIterator out ) {
208        // Recurse on tuple types using index expressions on each component.
209        if ( auto tuple = expr->result.as<ast::TupleType>() ) {
210                ast::ptr<ast::Expr> cleanup = expr;
211                for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
212                        explodeSimple( location,
213                                new ast::TupleIndexExpr( location, expr, i ), out );
214                }
215        // For a non-tuple type, output a clone of the expression.
216        } else {
217                *out++ = expr;
218        }
219}
220
221// Restructures arguments to match the structure of the formal parameters
222// of the actual function. Returns the next structured argument.
223template<typename Iterator>
224const ast::Expr * structureArg(
225                const CodeLocation& location, const ast::ptr<ast::Type> & type,
226                Iterator & begin, const Iterator & end ) {
227        if ( auto tuple = type.as<ast::TupleType>() ) {
228                std::vector<ast::ptr<ast::Expr>> exprs;
229                for ( const ast::ptr<ast::Type> & t : *tuple ) {
230                        exprs.push_back( structureArg( location, t, begin, end ) );
231                }
232                return new ast::TupleExpr( location, std::move( exprs ) );
233        } else {
234                assertf( begin != end, "reached the end of the arguments while structuring" );
235                return *begin++;
236        }
237}
238
239struct TypeInstFixer final : public ast::WithShortCircuiting {
240        std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
241
242        void previsit(const ast::TypeDecl *) { visit_children = false; }
243        const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
244                if (typeMap.count(typeInst->base)) {
245                        ast::TypeInstType * newInst = mutate(typeInst);
246                        auto const & pair = typeMap[typeInst->base];
247                        newInst->expr_id = pair.first;
248                        newInst->formal_usage = pair.second;
249                        return newInst;
250                }
251                return typeInst;
252        }
253};
254
255const ast::Expr * SpecializeCore::createThunkFunction(
256                const CodeLocation & location,
257                const ast::FunctionType * funType,
258                const ast::Expr * actual,
259                const ast::InferredParams * inferParams ) {
260        // One set of unique names per program.
261        static UniqueName thunkNamer("_thunk");
262
263        const ast::FunctionType * newType = ast::deepCopy( funType );
264        if ( typeSubs ) {
265                // Must replace only occurrences of type variables
266                // that occure free in the thunk's type.
267                auto result = typeSubs->applyFree( newType );
268                newType = result.node.release();
269        }
270
271        using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
272        using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
273
274        UniqueName paramNamer( paramPrefix );
275
276        // Create new thunk with same signature as formal type.
277        ast::Pass<TypeInstFixer> fixer;
278        for (const auto & kv : newType->forall) {
279                if (fixer.core.typeMap.count(kv->base)) {
280                        std::cerr << location << ' ' << kv->base->name
281                                << ' ' << kv->expr_id << '_' << kv->formal_usage
282                                << ',' << fixer.core.typeMap[kv->base].first
283                                << '_' << fixer.core.typeMap[kv->base].second << std::endl;
284                        assertf(false, "multiple formals in specialize");
285                }
286                else {
287                        fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
288                }
289        }
290
291        ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
292        ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
293                location,
294                thunkNamer.newName(),
295                map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
296                        return ast::deepCopy( inst->base );
297                } ),
298                map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
299                        return ast::deepCopy( expr->var );
300                } ),
301                map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
302                        auto param = new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
303                        param->attributes.push_back( new ast::Attribute( "unused" ) );
304                        return param;
305                } ),
306                map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
307                        return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
308                } ),
309                thunkBody,
310                ast::Storage::Classes(),
311                ast::Linkage::C
312                );
313
314        thunkFunc->fixUniqueId();
315
316        // Thunks may be generated and not used, avoid them.
317        thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
318
319        // Global thunks must be static to avoid collitions.
320        // Nested thunks must not be unique and hence, not static.
321        thunkFunc->storage.is_static = !isInFunction();
322
323        // Weave thunk parameters into call to actual function,
324        // naming thunk parameters as we go.
325        ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
326
327        const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
328        if ( typeSubs ) {
329                // Need to apply the environment to the actual function's type,
330                // since it may itself be polymorphic.
331                auto result = typeSubs->apply( actualType );
332                actualType = result.node.release();
333        }
334
335        ast::ptr<ast::FunctionType> actualTypeManager = actualType;
336
337        std::vector<ast::ptr<ast::Expr>> args;
338        for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
339                // Name each thunk parameter and explode it.
340                // These are then threaded back into the actual function call.
341                ast::DeclWithType * mutParam = ast::mutate( param.get() );
342                explodeSimple( location, new ast::VariableExpr( location, mutParam ),
343                        std::back_inserter( args ) );
344        }
345
346        // Walk parameters to the actual function alongside the exploded thunk
347        // parameters and restructure the arguments to match the actual parameters.
348        std::vector<ast::ptr<ast::Expr>>::iterator
349                argBegin = args.begin(), argEnd = args.end();
350        for ( const auto & actualArg : actualType->params ) {
351                app->args.push_back(
352                        structureArg( location, actualArg.get(), argBegin, argEnd ) );
353        }
354        assertf( argBegin == argEnd, "Did not structure all arguments." );
355
356        app->accept(fixer); // this should modify in place
357
358        app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
359        if ( inferParams ) {
360                app->inferred.inferParams() = *inferParams;
361        }
362
363        // Handle any specializations that may still be present.
364        {
365                std::string oldParamPrefix = paramPrefix;
366                paramPrefix += "p";
367                std::list<ast::ptr<ast::Decl>> oldDecls;
368                oldDecls.splice( oldDecls.end(), declsToAddBefore );
369
370                app->accept( *visitor );
371                // Write recursive specializations into the thunk body.
372                for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
373                        thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
374                }
375
376                declsToAddBefore = std::move( oldDecls );
377                paramPrefix = std::move( oldParamPrefix );
378        }
379
380        // Add return (or valueless expression) to the thunk.
381        ast::Stmt * appStmt;
382        if ( funType->returns.empty() ) {
383                appStmt = new ast::ExprStmt( app->location, app );
384        } else {
385                appStmt = new ast::ReturnStmt( app->location, app );
386        }
387        thunkBody->push_back( appStmt );
388
389        // Add the thunk definition:
390        declsToAddBefore.push_back( thunkFunc );
391
392        // Return address of thunk function as replacement expression.
393        return new ast::AddressExpr( location,
394                new ast::VariableExpr( location, thunkFunc ) );
395}
396
397const ast::Expr * SpecializeCore::doSpecialization(
398                const CodeLocation & location,
399                const ast::Type * formalType,
400                const ast::Expr * actual,
401                const ast::InferredParams * inferParams ) {
402        assertf( actual->result, "attempting to specialize an untyped expression" );
403        if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
404                if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
405                        if ( const ast::ApplicationExpr * expr =
406                                        dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
407                                return createThunkFunction( location, type, expr->func, inferParams );
408                        } else if ( auto expr =
409                                        dynamic_cast<const ast::VariableExpr *>( actual ) ) {
410                                return createThunkFunction( location, type, expr, inferParams );
411                        } else {
412                                // (I don't even know what that comment means.)
413                                // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
414                                return createThunkFunction( location, type, actual, inferParams );
415                        }
416                } else {
417                        return actual;
418                }
419        } else {
420                return actual;
421        }
422}
423
424const ast::Expr * SpecializeCore::postvisit(
425                const ast::ApplicationExpr * expr ) {
426        if ( ast::isIntrinsicCallExpr( expr ) ) {
427                return expr;
428        }
429
430        // Create thunks for the inferred parameters.
431        // This is not needed for intrinsic calls, because they aren't
432        // actually passed to the function. It needs to handle explicit params
433        // before inferred params so that explicit params do not recieve a
434        // changed set of inferParams (and change them again).
435        // Alternatively, if order starts to matter then copy expr's inferParams
436        // and pass them to handleExplicitParams.
437        ast::ApplicationExpr * mut = handleExplicitParams( expr );
438        if ( !mut->inferred.hasParams() ) {
439                return mut;
440        }
441        ast::InferredParams & inferParams = mut->inferred.inferParams();
442        for ( ast::InferredParams::value_type & inferParam : inferParams ) {
443                inferParam.second.expr = doSpecialization(
444                        inferParam.second.expr->location,
445                        inferParam.second.formalType,
446                        inferParam.second.expr,
447                        getInferredParams( inferParam.second.expr )
448                );
449        }
450        return mut;
451}
452
453const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
454        if ( expr->result->isVoid() ) {
455                // No specialization if there is no return value.
456                return expr;
457        }
458        const ast::Expr * specialized = doSpecialization(
459                expr->location, expr->result, expr->arg, getInferredParams( expr ) );
460        if ( specialized != expr->arg ) {
461                // Assume that the specialization incorporates the cast.
462                return specialized;
463        } else {
464                return expr;
465        }
466}
467
468} // namespace
469
470void convertSpecializations( ast::TranslationUnit & translationUnit ) {
471        ast::Pass<SpecializeCore>::run( translationUnit );
472}
473
474} // namespace GenPoly
475
476// Local Variables: //
477// tab-width: 4 //
478// mode: c++ //
479// compile-command: "make install" //
480// End: //
Note: See TracBrowser for help on using the repository browser.