source: src/GenPoly/SpecializeNew.cpp @ dd33c1f

ADTast-experimentalpthread-emulation
Last change on this file since dd33c1f was dd33c1f, checked in by Fangren Yu <f37yu@…>, 2 years ago

Merge branch 'master' of plg.uwaterloo.ca:software/cfa/cfa-cc

  • Property mode set to 100644
File size: 16.0 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecializeNew.cpp -- Generate thunks to specialize polymorphic functions.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Jun  7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun  7 13:37:00 2022
13// Update Count     : 0
14//
15
16#include "Specialize.h"
17
18#include "AST/Pass.hpp"
19#include "AST/TypeEnvironment.hpp"       // for OpenVarSet, AssertionSet
20#include "Common/UniqueName.h"           // for UniqueName
21#include "GenPoly/GenPoly.h"             // for getFunctionType
22#include "InitTweak/InitTweak.h"         // for isIntrinsicCallExpr
23#include "ResolvExpr/FindOpenVars.h"     // for findOpenVars
24#include "ResolvExpr/TypeEnvironment.h"  // for FirstOpen, FirstClosed
25
26#include "AST/Print.hpp"
27
28namespace GenPoly {
29
30namespace {
31
32struct SpecializeCore final :
33                public ast::WithConstTypeSubstitution,
34                public ast::WithDeclsToAdd<>,
35                public ast::WithVisitorRef<SpecializeCore> {
36        std::string paramPrefix = "_p";
37
38        ast::ApplicationExpr * handleExplicitParams(
39                const ast::ApplicationExpr * expr );
40        const ast::Expr * createThunkFunction(
41                const CodeLocation & location,
42                const ast::FunctionType * funType,
43                const ast::Expr * actual,
44                const ast::InferredParams * inferParams );
45        const ast::Expr * doSpecialization(
46                const CodeLocation & location,
47                const ast::Type * formalType,
48                const ast::Expr * actual,
49                const ast::InferredParams * inferParams );
50
51        const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
52        const ast::Expr * postvisit( const ast::CastExpr * expr );
53};
54
55const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
56        const ast::Expr::InferUnion & inferred = expr->inferred;
57        if ( inferred.hasParams() ) {
58                return &inferred.inferParams();
59        } else {
60                return nullptr;
61        }
62}
63
64// Check if both types have the same structure. The leaf (non-tuple) types
65// don't have to match but the tuples must match.
66bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
67        const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
68        const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
69        if ( tt0 && tt1 ) {
70                if ( tt0->size() != tt1->size() ) {
71                        return false;
72                }
73                for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
74                        if ( !isTupleStructureMatching(
75                                        std::get<0>( types ), std::get<1>( types ) ) ) {
76                                return false;
77                        }
78                }
79                return true;
80        }
81        return (!tt0 && !tt1);
82}
83
84// The number of elements in a type if it is a flattened tuple.
85size_t flatTupleSize( const ast::Type * type ) {
86        if ( auto tuple = dynamic_cast<const ast::TupleType *>( type ) ) {
87                size_t sum = 0;
88                for ( auto t : *tuple ) {
89                        sum += flatTupleSize( t );
90                }
91                return sum;
92        } else {
93                return 1;
94        }
95}
96
97// Find the total number of components in a parameter list.
98size_t functionParameterSize( const ast::FunctionType * type ) {
99        size_t sum = 0;
100        for ( auto param : type->params ) {
101                sum += flatTupleSize( param );
102        }
103        return sum;
104}
105
106bool needsPolySpecialization(
107                const ast::Type * formalType,
108                const ast::Type * actualType,
109                const ast::TypeSubstitution * subs ) {
110        if ( !subs ) {
111                return false;
112        }
113
114        using namespace ResolvExpr;
115        ast::OpenVarSet openVars, closedVars;
116        ast::AssertionSet need, have;
117        findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
118        findOpenVars( actualType, openVars, closedVars, need, have, FirstOpen );
119        for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
120                const ast::Type * boundType = subs->lookup( openVar.first );
121                // If the variable is not bound, move onto the next variable.
122                if ( !boundType ) continue;
123
124                // Is the variable cound to another type variable?
125                if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
126                        if ( closedVars.find( *inst ) == closedVars.end() ) {
127                                return true;
128                        }
129                // Otherwise, the variable is bound to a concrete type.
130                } else {
131                        return true;
132                }
133        }
134        // None of the type variables are bound.
135        return false;
136}
137
138bool needsTupleSpecialization(
139                const ast::Type * formalType, const ast::Type * actualType ) {
140        // Needs tuple specialization if the structure of the formal type and
141        // actual type do not match.
142
143        // This is the case if the formal type has ttype polymorphism, or if the structure  of tuple types
144        // between the function do not match exactly.
145        if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
146                // A pack in the parameter or return type requires specialization.
147                if ( ftype->isTtype() ) {
148                        return true;
149                }
150                // Conversion of 0 to a function type does not require specialization.
151                if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
152                        return false;
153                }
154                const ast::FunctionType * atype =
155                        getFunctionType( actualType->stripReferences() );
156                assertf( atype,
157                        "formal type is a function type, but actual type is not: %s",
158                        toString( actualType ).c_str() );
159                // Can't tuple specialize if parameter sizes deeply-differ.
160                if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
161                        return false;
162                }
163                // If tuple parameter size matches but actual parameter sizes differ
164                // then there needs to be specialization.
165                if ( ftype->params.size() != atype->params.size() ) {
166                        return true;
167                }
168                // Total parameter size can be the same, while individual parameters
169                // can have different structure.
170                for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
171                        if ( !isTupleStructureMatching(
172                                        std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
173                                return true;
174                        }
175                }
176        }
177        return false;
178}
179
180bool needsSpecialization(
181                const ast::Type * formalType, const ast::Type * actualType,
182                const ast::TypeSubstitution * subs ) {
183        return needsPolySpecialization( formalType, actualType, subs )
184                || needsTupleSpecialization( formalType, actualType );
185}
186
187ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
188                const ast::ApplicationExpr * expr ) {
189        assert( expr->func->result );
190        const ast::FunctionType * func = getFunctionType( expr->func->result );
191        assert( func );
192
193        ast::ApplicationExpr * mut = ast::mutate( expr );
194
195        std::vector<ast::ptr<ast::Type>>::const_iterator formal;
196        std::vector<ast::ptr<ast::Expr>>::iterator actual;
197        for ( formal = func->params.begin(), actual = mut->args.begin() ;
198                        formal != func->params.end() && actual != mut->args.end() ;
199                        ++formal, ++actual ) {
200                *actual = doSpecialization( (*actual)->location,
201                        *formal, *actual, getInferredParams( expr ) );
202        }
203        return mut;
204}
205
206// Explode assuming simple cases: either type is pure tuple (but not tuple
207// expr) or type is non-tuple.
208template<typename OutputIterator>
209void explodeSimple( const CodeLocation & location,
210                const ast::Expr * expr, OutputIterator out ) {
211        // Recurse on tuple types using index expressions on each component.
212        if ( auto tuple = expr->result.as<ast::TupleType>() ) {
213                ast::ptr<ast::Expr> cleanup = expr;
214                for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
215                        explodeSimple( location,
216                                new ast::TupleIndexExpr( location, expr, i ), out );
217                }
218        // For a non-tuple type, output a clone of the expression.
219        } else {
220                *out++ = expr;
221        }
222}
223
224// Restructures arguments to match the structure of the formal parameters
225// of the actual function. Returns the next structured argument.
226template<typename Iterator>
227const ast::Expr * structureArg(
228                const CodeLocation& location, const ast::ptr<ast::Type> & type,
229                Iterator & begin, const Iterator & end ) {
230        if ( auto tuple = type.as<ast::TupleType>() ) {
231                std::vector<ast::ptr<ast::Expr>> exprs;
232                for ( const ast::Type * t : *tuple ) {
233                        exprs.push_back( structureArg( location, t, begin, end ) );
234                }
235                return new ast::TupleExpr( location, std::move( exprs ) );
236        } else {
237                assertf( begin != end, "reached the end of the arguments while structuring" );
238                return *begin++;
239        }
240}
241
242namespace {
243        struct TypeInstFixer : public ast::WithShortCircuiting {
244                std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
245
246                void previsit(const ast::TypeDecl *) { visit_children = false; }
247                const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
248                        if (typeMap.count(typeInst->base)) {
249                                ast::TypeInstType * newInst = mutate(typeInst);
250                                newInst->expr_id = typeMap[typeInst->base].first;
251                                newInst->formal_usage = typeMap[typeInst->base].second;
252                                return newInst;
253                        }
254                        return typeInst;
255                }
256        };
257}
258
259const ast::Expr * SpecializeCore::createThunkFunction(
260                const CodeLocation & location,
261                const ast::FunctionType * funType,
262                const ast::Expr * actual,
263                const ast::InferredParams * inferParams ) {
264        // One set of unique names per program.
265        static UniqueName thunkNamer("_thunk");
266
267        const ast::FunctionType * newType = ast::deepCopy( funType );
268        if ( typeSubs ) {
269                // Must replace only occurrences of type variables
270                // that occure free in the thunk's type.
271                auto result = typeSubs->applyFree( newType );
272                newType = result.node.release();
273        }
274
275        using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
276        using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
277
278        UniqueName paramNamer( paramPrefix );
279
280        // Create new thunk with same signature as formal type.
281        ast::Pass<TypeInstFixer> fixer;
282        for (const auto & kv : newType->forall) {
283                if (fixer.core.typeMap.count(kv->base)) {
284                        std::cerr << location << ' ' << kv->base->name
285                                << ' ' << kv->expr_id << '_' << kv->formal_usage
286                                << ',' << fixer.core.typeMap[kv->base].first
287                                << '_' << fixer.core.typeMap[kv->base].second << std::endl;
288                        assertf(false, "multiple formals in specialize");
289                }
290                else {
291                        fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
292                }
293        }
294
295        ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
296        ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
297                location,
298                thunkNamer.newName(),
299                map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
300                        return ast::deepCopy( inst->base );
301                } ),
302                map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
303                        return ast::deepCopy( expr->var );
304                } ),
305                map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
306                        return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
307                } ),
308                map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
309                        return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
310                } ),
311                thunkBody,
312                ast::Storage::Classes(),
313                ast::Linkage::C
314                );
315
316        thunkFunc->fixUniqueId();
317
318        // Thunks may be generated and not used, avoid them.
319        thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
320
321        // Global thunks must be static to avoid collitions.
322        // Nested thunks must not be unique and hence, not static.
323        thunkFunc->storage.is_static = !isInFunction();
324
325        // Weave thunk parameters into call to actual function,
326        // naming thunk parameters as we go.
327        ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
328
329        const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
330        if ( typeSubs ) {
331                // Need to apply the environment to the actual function's type,
332                // since it may itself be polymorphic.
333                auto result = typeSubs->apply( actualType );
334                actualType = result.node.release();
335        }
336
337        ast::ptr<ast::FunctionType> actualTypeManager = actualType;
338
339        std::vector<ast::ptr<ast::Expr>> args;
340        for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
341                // Name each thunk parameter and explode it.
342                // These are then threaded back into the actual function call.
343                ast::DeclWithType * mutParam = ast::mutate( param.get() );
344                explodeSimple( location, new ast::VariableExpr( location, mutParam ),
345                        std::back_inserter( args ) );
346        }
347
348        // Walk parameters to the actual function alongside the exploded thunk
349        // parameters and restructure the arguments to match the actual parameters.
350        std::vector<ast::ptr<ast::Expr>>::iterator
351                argBegin = args.begin(), argEnd = args.end();
352        for ( const auto & actualArg : actualType->params ) {
353                app->args.push_back(
354                        structureArg( location, actualArg.get(), argBegin, argEnd ) );
355        }
356        assertf( argBegin == argEnd, "Did not structure all arguments." );
357
358        app->accept(fixer); // this should modify in place
359
360        app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
361        if ( inferParams ) {
362                app->inferred.inferParams() = *inferParams;
363        }
364
365        // Handle any specializations that may still be present.
366        {
367                std::string oldParamPrefix = paramPrefix;
368                paramPrefix += "p";
369                std::list<ast::ptr<ast::Decl>> oldDecls;
370                oldDecls.splice( oldDecls.end(), declsToAddBefore );
371
372                app->accept( *visitor );
373                // Write recursive specializations into the thunk body.
374                for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
375                        thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
376                }
377
378                declsToAddBefore = std::move( oldDecls );
379                paramPrefix = std::move( oldParamPrefix );
380        }
381
382        // Add return (or valueless expression) to the thunk.
383        ast::Stmt * appStmt;
384        if ( funType->returns.empty() ) {
385                appStmt = new ast::ExprStmt( app->location, app );
386        } else {
387                appStmt = new ast::ReturnStmt( app->location, app );
388        }
389        thunkBody->push_back( appStmt );
390
391        // Add the thunk definition:
392        declsToAddBefore.push_back( thunkFunc );
393
394        // Return address of thunk function as replacement expression.
395        return new ast::AddressExpr( location,
396                new ast::VariableExpr( location, thunkFunc ) );
397}
398
399const ast::Expr * SpecializeCore::doSpecialization(
400                const CodeLocation & location,
401                const ast::Type * formalType,
402                const ast::Expr * actual,
403                const ast::InferredParams * inferParams ) {
404        assertf( actual->result, "attempting to specialize an untyped expression" );
405        if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
406                if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
407                        if ( const ast::ApplicationExpr * expr =
408                                        dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
409                                return createThunkFunction( location, type, expr->func, inferParams );
410                        } else if ( auto expr =
411                                        dynamic_cast<const ast::VariableExpr *>( actual ) ) {
412                                return createThunkFunction( location, type, expr, inferParams );
413                        } else {
414                                // (I don't even know what that comment means.)
415                                // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
416                                return createThunkFunction( location, type, actual, inferParams );
417                        }
418                } else {
419                        return actual;
420                }
421        } else {
422                return actual;
423        }
424}
425
426const ast::Expr * SpecializeCore::postvisit(
427                const ast::ApplicationExpr * expr ) {
428        if ( InitTweak::isIntrinsicCallExpr( expr ) ) {
429                return expr;
430        }
431
432        // Create thunks for the inferred parameters.
433        // This is not needed for intrinsic calls, because they aren't
434        // actually passed to the function. It needs to handle explicit params
435        // before inferred params so that explicit params do not recieve a
436        // changed set of inferParams (and change them again).
437        // Alternatively, if order starts to matter then copy expr's inferParams
438        // and pass them to handleExplicitParams.
439        ast::ApplicationExpr * mut = handleExplicitParams( expr );
440        if ( !mut->inferred.hasParams() ) {
441                return mut;
442        }
443        ast::InferredParams & inferParams = mut->inferred.inferParams();
444        for ( ast::InferredParams::value_type & inferParam : inferParams ) {
445                inferParam.second.expr = doSpecialization(
446                        inferParam.second.expr->location,
447                        inferParam.second.formalType,
448                        inferParam.second.expr,
449                        getInferredParams( inferParam.second.expr )
450                );
451        }
452        return mut;
453}
454
455const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
456        if ( expr->result->isVoid() ) {
457                // No specialization if there is no return value.
458                return expr;
459        }
460        const ast::Expr * specialized = doSpecialization(
461                expr->location, expr->result, expr->arg, getInferredParams( expr ) );
462        if ( specialized != expr->arg ) {
463                // Assume that the specialization incorporates the cast.
464                // std::cerr << expr <<std::endl;
465                return specialized;
466        } else {
467                return expr;
468        }
469}
470
471} // namespace
472
473void convertSpecializations( ast::TranslationUnit & translationUnit ) {
474        ast::Pass<SpecializeCore>::run( translationUnit );
475}
476
477} // namespace GenPoly
478
479// Local Variables: //
480// tab-width: 4 //
481// mode: c++ //
482// compile-command: "make install" //
483// End: //
Note: See TracBrowser for help on using the repository browser.