// // Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo // // The contents of this file are covered under the licence agreement in the // file "LICENCE" distributed with Cforall. // // Specialize.cc -- // // Author : Richard C. Bilson // Created On : Mon May 18 07:44:20 2015 // Last Modified By : Rob Schluntz // Last Modified On : Thu Apr 28 15:17:45 2016 // Update Count : 24 // #include #include "Specialize.h" #include "GenPoly.h" #include "PolyMutator.h" #include "Parser/ParseNode.h" #include "SynTree/Expression.h" #include "SynTree/Statement.h" #include "SynTree/Type.h" #include "SynTree/Attribute.h" #include "SynTree/TypeSubstitution.h" #include "SynTree/Mutator.h" #include "ResolvExpr/FindOpenVars.h" #include "Common/UniqueName.h" #include "Common/utility.h" #include "InitTweak/InitTweak.h" namespace GenPoly { class Specializer; class Specialize final : public PolyMutator { friend class Specializer; public: using PolyMutator::mutate; virtual Expression * mutate( ApplicationExpr *applicationExpr ) override; virtual Expression * mutate( AddressExpr *castExpr ) override; virtual Expression * mutate( CastExpr *castExpr ) override; // virtual Expression * mutate( LogicalExpr *logicalExpr ); // virtual Expression * mutate( ConditionalExpr *conditionalExpr ); // virtual Expression * mutate( CommaExpr *commaExpr ); Specializer * specializer = nullptr; void handleExplicitParams( ApplicationExpr *appExpr ); }; class Specializer { public: Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {} virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0; virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0; virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 ); protected: Specialize & spec; std::string paramPrefix = "_p"; TypeSubstitution *& env; std::list< Statement * > & stmtsToAdd; }; // for normal polymorphic -> monomorphic function conversion class PolySpecializer : public Specializer { public: PolySpecializer( Specialize & spec ) : Specializer( spec ) {} virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override; virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override; }; // // for tuple -> non-tuple function conversion class TupleSpecializer : public Specializer { public: TupleSpecializer( Specialize & spec ) : Specializer( spec ) {} virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override; virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override; }; /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type. bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) { if ( env ) { using namespace ResolvExpr; OpenVarSet openVars, closedVars; AssertionSet need, have; findOpenVars( formalType, openVars, closedVars, need, have, false ); findOpenVars( actualType, openVars, closedVars, need, have, true ); for ( OpenVarSet::const_iterator openVar = openVars.begin(); openVar != openVars.end(); ++openVar ) { Type *boundType = env->lookup( openVar->first ); if ( ! boundType ) continue; if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( boundType ) ) { if ( closedVars.find( typeInst->get_name() ) == closedVars.end() ) { return true; } // if } else { return true; } // if } // for return false; } else { return false; } // if } /// Generates a thunk that calls `actual` with type `funType` and returns its address Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) { static UniqueName thunkNamer( "_thunk" ); FunctionType *newType = funType->clone(); if ( env ) { TypeSubstitution newEnv( *env ); // it is important to replace only occurrences of type variables that occur free in the // thunk's type newEnv.applyFree( newType ); } // if // create new thunk with same signature as formal type (C linkage, empty body) FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false ); thunkFunc->fixUniqueId(); // thunks may be generated and not used - silence warning with attribute thunkFunc->get_attributes().push_back( new Attribute( "unused" ) ); // thread thunk parameters into call to actual function, naming thunk parameters as we go UniqueName paramNamer( paramPrefix ); ApplicationExpr *appExpr = new ApplicationExpr( actual ); for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) { (*param )->set_name( paramNamer.newName() ); appExpr->get_args().push_back( new VariableExpr( *param ) ); } // for appExpr->set_env( maybeClone( env ) ); if ( inferParams ) { appExpr->get_inferParams() = *inferParams; } // if // handle any specializations that may still be present std::string oldParamPrefix = paramPrefix; paramPrefix += "p"; // save stmtsToAdd in oldStmts std::list< Statement* > oldStmts; oldStmts.splice( oldStmts.end(), stmtsToAdd ); spec.handleExplicitParams( appExpr ); paramPrefix = oldParamPrefix; // write any statements added for recursive specializations into the thunk body thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd ); // restore oldStmts into stmtsToAdd stmtsToAdd.splice( stmtsToAdd.end(), oldStmts ); // add return (or valueless expression) to the thunk Statement *appStmt; if ( funType->get_returnVals().empty() ) { appStmt = new ExprStmt( noLabels, appExpr ); } else { appStmt = new ReturnStmt( noLabels, appExpr ); } // if thunkFunc->get_statements()->get_kids().push_back( appStmt ); // add thunk definition to queue of statements to add stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) ); // return address of thunk function as replacement expression return new AddressExpr( new VariableExpr( thunkFunc ) ); } Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) { assertf( actual->has_result(), "attempting to specialize an untyped expression" ); if ( needsSpecialization( formalType, actual->get_result(), env ) ) { FunctionType *funType; if ( ( funType = getFunctionType( formalType ) ) ) { ApplicationExpr *appExpr; VariableExpr *varExpr; if ( ( appExpr = dynamic_cast( actual ) ) ) { return createThunkFunction( funType, appExpr->get_function(), inferParams ); } else if ( ( varExpr = dynamic_cast( actual ) ) ) { return createThunkFunction( funType, varExpr, inferParams ); } else { // This likely won't work, as anything that could build an ApplicationExpr probably hit one of the previous two branches return createThunkFunction( funType, actual, inferParams ); } } else { return actual; } // if } else { return actual; } // if } bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) { // std::cerr << "asking if type needs tuple spec: " << formalType << std::endl; if ( FunctionType * ftype = getFunctionType( formalType ) ) { return ftype->isTtype(); } return false; } /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this) template< typename OutIterator > void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) { if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) { std::list< Expression * > exprs; for ( Type * t : *tupleType ) { matchOneFormal( arg, idx, t, back_inserter( exprs ) ); } *out++ = new TupleExpr( exprs ); } else { *out++ = new TupleIndexExpr( arg->clone(), idx++ ); } } /// restructures the ttype argument to match the structure of the formal parameters of the actual function. // [begin, end) are the formal parameters. // args is the list of arguments currently given to the actual function, the last of which needs to be restructured. template< typename Iterator > void fixLastArg( std::list< Expression * > & args, Iterator begin, Iterator end ) { assertf( ! args.empty(), "Somehow args to tuple function are empty" ); // xxx - it's quite possible this will trigger for the nullary case... Expression * last = args.back(); // safe_dynamic_cast for the assertion safe_dynamic_cast< TupleType * >( last->get_result() ); // xxx - it's quite possible this will trigger for the unary case... args.pop_back(); // replace last argument in the call with unsigned idx = 0; for ( ; begin != end; ++begin ) { DeclarationWithType * formal = *begin; Type * formalType = formal->get_type(); matchOneFormal( last, idx, formalType, back_inserter( args ) ); } delete last; } Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) { static UniqueName thunkNamer( "_tupleThunk" ); // std::cerr << "creating tuple thunk for " << funType << std::endl; FunctionType *newType = funType->clone(); if ( env ) { TypeSubstitution newEnv( *env ); // it is important to replace only occurrences of type variables that occur free in the // thunk's type newEnv.applyFree( newType ); } // if // create new thunk with same signature as formal type (C linkage, empty body) FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false ); thunkFunc->fixUniqueId(); // thunks may be generated and not used - silence warning with attribute thunkFunc->get_attributes().push_back( new Attribute( "unused" ) ); // thread thunk parameters into call to actual function, naming thunk parameters as we go UniqueName paramNamer( paramPrefix ); ApplicationExpr *appExpr = new ApplicationExpr( actual ); // std::cerr << actual << std::endl; FunctionType * actualType = getFunctionType( actual->get_result() ); std::list< DeclarationWithType * >::iterator begin = actualType->get_parameters().begin(); std::list< DeclarationWithType * >::iterator end = actualType->get_parameters().end(); for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) { // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function. assert( begin != end ); ++begin; // std::cerr << "thunk param: " << param << std::endl; // last param will always be a tuple type... expand it into the actual type(?) param->set_name( paramNamer.newName() ); appExpr->get_args().push_back( new VariableExpr( param ) ); } // for fixLastArg( appExpr->get_args(), --begin, end ); appExpr->set_env( maybeClone( env ) ); if ( inferParams ) { appExpr->get_inferParams() = *inferParams; } // if // handle any specializations that may still be present std::string oldParamPrefix = paramPrefix; paramPrefix += "p"; // save stmtsToAdd in oldStmts std::list< Statement* > oldStmts; oldStmts.splice( oldStmts.end(), stmtsToAdd ); spec.handleExplicitParams( appExpr ); paramPrefix = oldParamPrefix; // write any statements added for recursive specializations into the thunk body thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd ); // restore oldStmts into stmtsToAdd stmtsToAdd.splice( stmtsToAdd.end(), oldStmts ); // add return (or valueless expression) to the thunk Statement *appStmt; if ( funType->get_returnVals().empty() ) { appStmt = new ExprStmt( noLabels, appExpr ); } else { appStmt = new ReturnStmt( noLabels, appExpr ); } // if thunkFunc->get_statements()->get_kids().push_back( appStmt ); // std::cerr << "thunkFunc is: " << thunkFunc << std::endl; // add thunk definition to queue of statements to add stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) ); // return address of thunk function as replacement expression return new AddressExpr( new VariableExpr( thunkFunc ) ); } void Specialize::handleExplicitParams( ApplicationExpr *appExpr ) { // create thunks for the explicit parameters assert( appExpr->get_function()->has_result() ); FunctionType *function = getFunctionType( appExpr->get_function()->get_result() ); assert( function ); std::list< DeclarationWithType* >::iterator formal; std::list< Expression* >::iterator actual; for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) { *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() ); } } Expression * Specialize::mutate( ApplicationExpr *appExpr ) { appExpr->get_function()->acceptMutator( *this ); mutateAll( appExpr->get_args(), *this ); if ( ! InitTweak::isIntrinsicCallExpr( appExpr ) ) { // create thunks for the inferred parameters // don't need to do this for intrinsic calls, because they aren't actually passed for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) { inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, &appExpr->get_inferParams() ); } handleExplicitParams( appExpr ); } return appExpr; } Expression * Specialize::mutate( AddressExpr *addrExpr ) { addrExpr->get_arg()->acceptMutator( *this ); assert( addrExpr->has_result() ); addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) ); return addrExpr; } Expression * Specialize::mutate( CastExpr *castExpr ) { castExpr->get_arg()->acceptMutator( *this ); if ( castExpr->get_result()->isVoid() ) { // can't specialize if we don't have a return value return castExpr; } Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() ); if ( specialized != castExpr->get_arg() ) { // assume here that the specialization incorporates the cast return specialized; } else { return castExpr; } } // Removing these for now. Richard put these in for some reason, but it's not clear why. // In particular, copy constructors produce a comma expression, and with this code the parts // of that comma expression are not specialized, which causes problems. // Expression * Specialize::mutate( LogicalExpr *logicalExpr ) { // return logicalExpr; // } // Expression * Specialize::mutate( ConditionalExpr *condExpr ) { // return condExpr; // } // Expression * Specialize::mutate( CommaExpr *commaExpr ) { // return commaExpr; // } void convertSpecializations( std::list< Declaration* >& translationUnit ) { Specialize spec; TupleSpecializer tupleSpec( spec ); spec.specializer = &tupleSpec; mutateAll( translationUnit, spec ); PolySpecializer polySpec( spec ); spec.specializer = &polySpec; mutateAll( translationUnit, spec ); } } // namespace GenPoly // Local Variables: // // tab-width: 4 // // mode: c++ // // compile-command: "make install" // // End: //