source: src/GenPoly/SpecializeNew.cpp @ a983cbf

Last change on this file since a983cbf was 24d6572, checked in by Fangren Yu <f37yu@…>, 13 months ago

Merge branch 'master' into ast-experimental

  • Property mode set to 100644
File size: 16.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecializeNew.cpp -- Generate thunks to specialize polymorphic functions.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Jun  7 13:37:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun  7 13:37:00 2022
13// Update Count     : 0
14//
15
16#include "Specialize.h"
17
18#include "AST/Copy.hpp"                  // for deepCopy
19#include "AST/Inspect.hpp"               // for isIntrinsicCallExpr
20#include "AST/Pass.hpp"                  // for Pass
21#include "AST/TypeEnvironment.hpp"       // for OpenVarSet, AssertionSet
22#include "Common/UniqueName.h"           // for UniqueName
23#include "GenPoly/GenPoly.h"             // for getFunctionType
24#include "ResolvExpr/FindOpenVars.h"     // for findOpenVars
25#include "ResolvExpr/TypeEnvironment.h"  // for FirstOpen, FirstClosed
26
27namespace GenPoly {
28
29namespace {
30
31struct SpecializeCore final :
32                public ast::WithConstTypeSubstitution,
33                public ast::WithDeclsToAdd<>,
34                public ast::WithVisitorRef<SpecializeCore> {
35        std::string paramPrefix = "_p";
36
37        ast::ApplicationExpr * handleExplicitParams(
38                const ast::ApplicationExpr * expr );
39        const ast::Expr * createThunkFunction(
40                const CodeLocation & location,
41                const ast::FunctionType * funType,
42                const ast::Expr * actual,
43                const ast::InferredParams * inferParams );
44        const ast::Expr * doSpecialization(
45                const CodeLocation & location,
46                const ast::Type * formalType,
47                const ast::Expr * actual,
48                const ast::InferredParams * inferParams );
49
50        const ast::Expr * postvisit( const ast::ApplicationExpr * expr );
51        const ast::Expr * postvisit( const ast::CastExpr * expr );
52};
53
54const ast::InferredParams * getInferredParams( const ast::Expr * expr ) {
55        const ast::Expr::InferUnion & inferred = expr->inferred;
56        if ( inferred.hasParams() ) {
57                return &inferred.inferParams();
58        } else {
59                return nullptr;
60        }
61}
62
63// Check if both types have the same structure. The leaf (non-tuple) types
64// don't have to match but the tuples must match.
65bool isTupleStructureMatching( const ast::Type * t0, const ast::Type * t1 ) {
66        const ast::TupleType * tt0 = dynamic_cast<const ast::TupleType *>( t0 );
67        const ast::TupleType * tt1 = dynamic_cast<const ast::TupleType *>( t1 );
68        if ( tt0 && tt1 ) {
69                if ( tt0->size() != tt1->size() ) {
70                        return false;
71                }
72                for ( auto types : group_iterate( tt0->types, tt1->types ) ) {
73                        if ( !isTupleStructureMatching(
74                                        std::get<0>( types ), std::get<1>( types ) ) ) {
75                                return false;
76                        }
77                }
78                return true;
79        }
80        return (!tt0 && !tt1);
81}
82
83// The number of elements in a type if it is a flattened tuple.
84size_t flatTupleSize( const ast::Type * type ) {
85        if ( auto tuple = dynamic_cast<const ast::TupleType *>( type ) ) {
86                size_t sum = 0;
87                for ( auto t : *tuple ) {
88                        sum += flatTupleSize( t );
89                }
90                return sum;
91        } else {
92                return 1;
93        }
94}
95
96// Find the total number of components in a parameter list.
97size_t functionParameterSize( const ast::FunctionType * type ) {
98        size_t sum = 0;
99        for ( auto param : type->params ) {
100                sum += flatTupleSize( param );
101        }
102        return sum;
103}
104
105bool needsPolySpecialization(
106                const ast::Type * formalType,
107                const ast::Type * actualType,
108                const ast::TypeSubstitution * subs ) {
109        if ( !subs ) {
110                return false;
111        }
112
113        using namespace ResolvExpr;
114        ast::OpenVarSet openVars, closedVars;
115        ast::AssertionSet need, have; // unused
116        ast::TypeEnvironment env; // unused
117        // findOpenVars( formalType, openVars, closedVars, need, have, FirstClosed );
118        findOpenVars( actualType, openVars, closedVars, need, have, env, FirstOpen );
119        for ( const ast::OpenVarSet::value_type & openVar : openVars ) {
120                const ast::Type * boundType = subs->lookup( openVar.first );
121                // If the variable is not bound, move onto the next variable.
122                if ( !boundType ) continue;
123
124                // Is the variable cound to another type variable?
125                if ( auto inst = dynamic_cast<const ast::TypeInstType *>( boundType ) ) {
126                        if ( closedVars.find( *inst ) == closedVars.end() ) {
127                                return true;
128                        }
129                        else {
130                                assertf(false, "closed: %s", inst->name.c_str());
131                        }
132                // Otherwise, the variable is bound to a concrete type.
133                } else {
134                        return true;
135                }
136        }
137        // None of the type variables are bound.
138        return false;
139}
140
141bool needsTupleSpecialization(
142                const ast::Type * formalType, const ast::Type * actualType ) {
143        // Needs tuple specialization if the structure of the formal type and
144        // actual type do not match.
145
146        // This is the case if the formal type has ttype polymorphism, or if the structure  of tuple types
147        // between the function do not match exactly.
148        if ( const ast::FunctionType * ftype = getFunctionType( formalType ) ) {
149                // A pack in the parameter or return type requires specialization.
150                if ( ftype->isTtype() ) {
151                        return true;
152                }
153                // Conversion of 0 to a function type does not require specialization.
154                if ( dynamic_cast<const ast::ZeroType *>( actualType ) ) {
155                        return false;
156                }
157                const ast::FunctionType * atype =
158                        getFunctionType( actualType->stripReferences() );
159                assertf( atype,
160                        "formal type is a function type, but actual type is not: %s",
161                        toString( actualType ).c_str() );
162                // Can't tuple specialize if parameter sizes deeply-differ.
163                if ( functionParameterSize( ftype ) != functionParameterSize( atype ) ) {
164                        return false;
165                }
166                // If tuple parameter size matches but actual parameter sizes differ
167                // then there needs to be specialization.
168                if ( ftype->params.size() != atype->params.size() ) {
169                        return true;
170                }
171                // Total parameter size can be the same, while individual parameters
172                // can have different structure.
173                for ( auto pairs : group_iterate( ftype->params, atype->params ) ) {
174                        if ( !isTupleStructureMatching(
175                                        std::get<0>( pairs ), std::get<1>( pairs ) ) ) {
176                                return true;
177                        }
178                }
179        }
180        return false;
181}
182
183bool needsSpecialization(
184                const ast::Type * formalType, const ast::Type * actualType,
185                const ast::TypeSubstitution * subs ) {
186        return needsPolySpecialization( formalType, actualType, subs )
187                || needsTupleSpecialization( formalType, actualType );
188}
189
190ast::ApplicationExpr * SpecializeCore::handleExplicitParams(
191                const ast::ApplicationExpr * expr ) {
192        assert( expr->func->result );
193        const ast::FunctionType * func = getFunctionType( expr->func->result );
194        assert( func );
195
196        ast::ApplicationExpr * mut = ast::mutate( expr );
197
198        std::vector<ast::ptr<ast::Type>>::const_iterator formal;
199        std::vector<ast::ptr<ast::Expr>>::iterator actual;
200        for ( formal = func->params.begin(), actual = mut->args.begin() ;
201                        formal != func->params.end() && actual != mut->args.end() ;
202                        ++formal, ++actual ) {
203                *actual = doSpecialization( (*actual)->location,
204                        *formal, *actual, getInferredParams( expr ) );
205        }
206        return mut;
207}
208
209// Explode assuming simple cases: either type is pure tuple (but not tuple
210// expr) or type is non-tuple.
211template<typename OutputIterator>
212void explodeSimple( const CodeLocation & location,
213                const ast::Expr * expr, OutputIterator out ) {
214        // Recurse on tuple types using index expressions on each component.
215        if ( auto tuple = expr->result.as<ast::TupleType>() ) {
216                ast::ptr<ast::Expr> cleanup = expr;
217                for ( unsigned int i = 0 ; i < tuple->size() ; ++i ) {
218                        explodeSimple( location,
219                                new ast::TupleIndexExpr( location, expr, i ), out );
220                }
221        // For a non-tuple type, output a clone of the expression.
222        } else {
223                *out++ = expr;
224        }
225}
226
227// Restructures arguments to match the structure of the formal parameters
228// of the actual function. Returns the next structured argument.
229template<typename Iterator>
230const ast::Expr * structureArg(
231                const CodeLocation& location, const ast::ptr<ast::Type> & type,
232                Iterator & begin, const Iterator & end ) {
233        if ( auto tuple = type.as<ast::TupleType>() ) {
234                std::vector<ast::ptr<ast::Expr>> exprs;
235                for ( const ast::ptr<ast::Type> & t : *tuple ) {
236                        exprs.push_back( structureArg( location, t, begin, end ) );
237                }
238                return new ast::TupleExpr( location, std::move( exprs ) );
239        } else {
240                assertf( begin != end, "reached the end of the arguments while structuring" );
241                return *begin++;
242        }
243}
244
245struct TypeInstFixer final : public ast::WithShortCircuiting {
246        std::map<const ast::TypeDecl *, std::pair<int, int>> typeMap;
247
248        void previsit(const ast::TypeDecl *) { visit_children = false; }
249        const ast::TypeInstType * postvisit(const ast::TypeInstType * typeInst) {
250                if (typeMap.count(typeInst->base)) {
251                        ast::TypeInstType * newInst = mutate(typeInst);
252                        auto const & pair = typeMap[typeInst->base];
253                        newInst->expr_id = pair.first;
254                        newInst->formal_usage = pair.second;
255                        return newInst;
256                }
257                return typeInst;
258        }
259};
260
261const ast::Expr * SpecializeCore::createThunkFunction(
262                const CodeLocation & location,
263                const ast::FunctionType * funType,
264                const ast::Expr * actual,
265                const ast::InferredParams * inferParams ) {
266        // One set of unique names per program.
267        static UniqueName thunkNamer("_thunk");
268
269        const ast::FunctionType * newType = ast::deepCopy( funType );
270        if ( typeSubs ) {
271                // Must replace only occurrences of type variables
272                // that occure free in the thunk's type.
273                auto result = typeSubs->applyFree( newType );
274                newType = result.node.release();
275        }
276
277        using DWTVector = std::vector<ast::ptr<ast::DeclWithType>>;
278        using DeclVector = std::vector<ast::ptr<ast::TypeDecl>>;
279
280        UniqueName paramNamer( paramPrefix );
281
282        // Create new thunk with same signature as formal type.
283        ast::Pass<TypeInstFixer> fixer;
284        for (const auto & kv : newType->forall) {
285                if (fixer.core.typeMap.count(kv->base)) {
286                        std::cerr << location << ' ' << kv->base->name
287                                << ' ' << kv->expr_id << '_' << kv->formal_usage
288                                << ',' << fixer.core.typeMap[kv->base].first
289                                << '_' << fixer.core.typeMap[kv->base].second << std::endl;
290                        assertf(false, "multiple formals in specialize");
291                }
292                else {
293                        fixer.core.typeMap[kv->base] = std::make_pair(kv->expr_id, kv->formal_usage);
294                }
295        }
296
297        ast::CompoundStmt * thunkBody = new ast::CompoundStmt( location );
298        ast::FunctionDecl * thunkFunc = new ast::FunctionDecl(
299                location,
300                thunkNamer.newName(),
301                map_range<DeclVector>( newType->forall, []( const ast::TypeInstType * inst ) {
302                        return ast::deepCopy( inst->base );
303                } ),
304                map_range<DWTVector>( newType->assertions, []( const ast::VariableExpr * expr ) {
305                        return ast::deepCopy( expr->var );
306                } ),
307                map_range<DWTVector>( newType->params, [&location, &paramNamer]( const ast::Type * type ) {
308                        return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
309                } ),
310                map_range<DWTVector>( newType->returns, [&location, &paramNamer]( const ast::Type * type ) {
311                        return new ast::ObjectDecl( location, paramNamer.newName(), ast::deepCopy( type ) );
312                } ),
313                thunkBody,
314                ast::Storage::Classes(),
315                ast::Linkage::C
316                );
317
318        thunkFunc->fixUniqueId();
319
320        // Thunks may be generated and not used, avoid them.
321        thunkFunc->attributes.push_back( new ast::Attribute( "unused" ) );
322
323        // Global thunks must be static to avoid collitions.
324        // Nested thunks must not be unique and hence, not static.
325        thunkFunc->storage.is_static = !isInFunction();
326
327        // Weave thunk parameters into call to actual function,
328        // naming thunk parameters as we go.
329        ast::ApplicationExpr * app = new ast::ApplicationExpr( location, actual );
330
331        const ast::FunctionType * actualType = ast::deepCopy( getFunctionType( actual->result ) );
332        if ( typeSubs ) {
333                // Need to apply the environment to the actual function's type,
334                // since it may itself be polymorphic.
335                auto result = typeSubs->apply( actualType );
336                actualType = result.node.release();
337        }
338
339        ast::ptr<ast::FunctionType> actualTypeManager = actualType;
340
341        std::vector<ast::ptr<ast::Expr>> args;
342        for ( ast::ptr<ast::DeclWithType> & param : thunkFunc->params ) {
343                // Name each thunk parameter and explode it.
344                // These are then threaded back into the actual function call.
345                ast::DeclWithType * mutParam = ast::mutate( param.get() );
346                explodeSimple( location, new ast::VariableExpr( location, mutParam ),
347                        std::back_inserter( args ) );
348        }
349
350        // Walk parameters to the actual function alongside the exploded thunk
351        // parameters and restructure the arguments to match the actual parameters.
352        std::vector<ast::ptr<ast::Expr>>::iterator
353                argBegin = args.begin(), argEnd = args.end();
354        for ( const auto & actualArg : actualType->params ) {
355                app->args.push_back(
356                        structureArg( location, actualArg.get(), argBegin, argEnd ) );
357        }
358        assertf( argBegin == argEnd, "Did not structure all arguments." );
359
360        app->accept(fixer); // this should modify in place
361
362        app->env = ast::TypeSubstitution::newFromExpr( app, typeSubs );
363        if ( inferParams ) {
364                app->inferred.inferParams() = *inferParams;
365        }
366
367        // Handle any specializations that may still be present.
368        {
369                std::string oldParamPrefix = paramPrefix;
370                paramPrefix += "p";
371                std::list<ast::ptr<ast::Decl>> oldDecls;
372                oldDecls.splice( oldDecls.end(), declsToAddBefore );
373
374                app->accept( *visitor );
375                // Write recursive specializations into the thunk body.
376                for ( const ast::ptr<ast::Decl> & decl : declsToAddBefore ) {
377                        thunkBody->push_back( new ast::DeclStmt( decl->location, decl ) );
378                }
379
380                declsToAddBefore = std::move( oldDecls );
381                paramPrefix = std::move( oldParamPrefix );
382        }
383
384        // Add return (or valueless expression) to the thunk.
385        ast::Stmt * appStmt;
386        if ( funType->returns.empty() ) {
387                appStmt = new ast::ExprStmt( app->location, app );
388        } else {
389                appStmt = new ast::ReturnStmt( app->location, app );
390        }
391        thunkBody->push_back( appStmt );
392
393        // Add the thunk definition:
394        declsToAddBefore.push_back( thunkFunc );
395
396        // Return address of thunk function as replacement expression.
397        return new ast::AddressExpr( location,
398                new ast::VariableExpr( location, thunkFunc ) );
399}
400
401const ast::Expr * SpecializeCore::doSpecialization(
402                const CodeLocation & location,
403                const ast::Type * formalType,
404                const ast::Expr * actual,
405                const ast::InferredParams * inferParams ) {
406        assertf( actual->result, "attempting to specialize an untyped expression" );
407        if ( needsSpecialization( formalType, actual->result, typeSubs ) ) {
408                if ( const ast::FunctionType * type = getFunctionType( formalType ) ) {
409                        if ( const ast::ApplicationExpr * expr =
410                                        dynamic_cast<const ast::ApplicationExpr *>( actual ) ) {
411                                return createThunkFunction( location, type, expr->func, inferParams );
412                        } else if ( auto expr =
413                                        dynamic_cast<const ast::VariableExpr *>( actual ) ) {
414                                return createThunkFunction( location, type, expr, inferParams );
415                        } else {
416                                // (I don't even know what that comment means.)
417                                // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches
418                                return createThunkFunction( location, type, actual, inferParams );
419                        }
420                } else {
421                        return actual;
422                }
423        } else {
424                return actual;
425        }
426}
427
428const ast::Expr * SpecializeCore::postvisit(
429                const ast::ApplicationExpr * expr ) {
430        if ( ast::isIntrinsicCallExpr( expr ) ) {
431                return expr;
432        }
433
434        // Create thunks for the inferred parameters.
435        // This is not needed for intrinsic calls, because they aren't
436        // actually passed to the function. It needs to handle explicit params
437        // before inferred params so that explicit params do not recieve a
438        // changed set of inferParams (and change them again).
439        // Alternatively, if order starts to matter then copy expr's inferParams
440        // and pass them to handleExplicitParams.
441        ast::ApplicationExpr * mut = handleExplicitParams( expr );
442        if ( !mut->inferred.hasParams() ) {
443                return mut;
444        }
445        ast::InferredParams & inferParams = mut->inferred.inferParams();
446        for ( ast::InferredParams::value_type & inferParam : inferParams ) {
447                inferParam.second.expr = doSpecialization(
448                        inferParam.second.expr->location,
449                        inferParam.second.formalType,
450                        inferParam.second.expr,
451                        getInferredParams( inferParam.second.expr )
452                );
453        }
454        return mut;
455}
456
457const ast::Expr * SpecializeCore::postvisit( const ast::CastExpr * expr ) {
458        if ( expr->result->isVoid() ) {
459                // No specialization if there is no return value.
460                return expr;
461        }
462        const ast::Expr * specialized = doSpecialization(
463                expr->location, expr->result, expr->arg, getInferredParams( expr ) );
464        if ( specialized != expr->arg ) {
465                // Assume that the specialization incorporates the cast.
466                return specialized;
467        } else {
468                return expr;
469        }
470}
471
472} // namespace
473
474void convertSpecializations( ast::TranslationUnit & translationUnit ) {
475        ast::Pass<SpecializeCore>::run( translationUnit );
476}
477
478} // namespace GenPoly
479
480// Local Variables: //
481// tab-width: 4 //
482// mode: c++ //
483// compile-command: "make install" //
484// End: //
Note: See TracBrowser for help on using the repository browser.