Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/GenPoly/Specialize.cc

    rf3b0a07 r6c3a988f  
    3535
    3636namespace GenPoly {
     37        class Specializer;
    3738        class Specialize final : public PolyMutator {
     39                friend class Specializer;
    3840          public:
    3941                using PolyMutator::mutate;
     
    4547                // virtual Expression * mutate( CommaExpr *commaExpr );
    4648
     49                Specializer * specializer = nullptr;
    4750                void handleExplicitParams( ApplicationExpr *appExpr );
    48                 Expression * createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams );
    49                 Expression * doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = nullptr );
    50 
     51        };
     52
     53        class Specializer {
     54          public:
     55                Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
     56                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
     57                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
     58                virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
     59
     60          protected:
     61                Specialize & spec;
    5162                std::string paramPrefix = "_p";
     63                TypeSubstitution *& env;
     64                std::list< Statement * > & stmtsToAdd;
    5265        };
    5366
     67        // for normal polymorphic -> monomorphic function conversion
     68        class PolySpecializer : public Specializer {
     69          public:
     70                PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
     71                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
     72                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
     73        };
     74
     75        // // for tuple -> non-tuple function conversion
     76        class TupleSpecializer : public Specializer {
     77          public:
     78                TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
     79                virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
     80                virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
     81        };
     82
    5483        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
    55         bool needsPolySpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
     84        bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    5685                if ( env ) {
    5786                        using namespace ResolvExpr;
     
    77106        }
    78107
    79         bool needsTupleSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    80                 if ( FunctionType * ftype = getFunctionType( formalType ) ) {
    81                         return ftype->isTtype();
    82                 }
    83                 return false;
    84         }
    85 
    86         bool needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    87                 return needsPolySpecialization( formalType, actualType, env ) || needsTupleSpecialization( formalType, actualType, env );
    88         }
    89 
    90         Expression * Specialize::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
     108        /// Generates a thunk that calls `actual` with type `funType` and returns its address
     109        Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
     110                static UniqueName thunkNamer( "_thunk" );
     111
     112                FunctionType *newType = funType->clone();
     113                if ( env ) {
     114                        // it is important to replace only occurrences of type variables that occur free in the
     115                        // thunk's type
     116                        env->applyFree( newType );
     117                } // if
     118                // create new thunk with same signature as formal type (C linkage, empty body)
     119                FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
     120                thunkFunc->fixUniqueId();
     121
     122                // thunks may be generated and not used - silence warning with attribute
     123                thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
     124
     125                // thread thunk parameters into call to actual function, naming thunk parameters as we go
     126                UniqueName paramNamer( paramPrefix );
     127                ApplicationExpr *appExpr = new ApplicationExpr( actual );
     128                for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
     129                        (*param )->set_name( paramNamer.newName() );
     130                        appExpr->get_args().push_back( new VariableExpr( *param ) );
     131                } // for
     132                appExpr->set_env( maybeClone( env ) );
     133                if ( inferParams ) {
     134                        appExpr->get_inferParams() = *inferParams;
     135                } // if
     136
     137                // handle any specializations that may still be present
     138                std::string oldParamPrefix = paramPrefix;
     139                paramPrefix += "p";
     140                // save stmtsToAdd in oldStmts
     141                std::list< Statement* > oldStmts;
     142                oldStmts.splice( oldStmts.end(), stmtsToAdd );
     143                spec.handleExplicitParams( appExpr );
     144                paramPrefix = oldParamPrefix;
     145                // write any statements added for recursive specializations into the thunk body
     146                thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
     147                // restore oldStmts into stmtsToAdd
     148                stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
     149
     150                // add return (or valueless expression) to the thunk
     151                Statement *appStmt;
     152                if ( funType->get_returnVals().empty() ) {
     153                        appStmt = new ExprStmt( noLabels, appExpr );
     154                } else {
     155                        appStmt = new ReturnStmt( noLabels, appExpr );
     156                } // if
     157                thunkFunc->get_statements()->get_kids().push_back( appStmt );
     158
     159                // add thunk definition to queue of statements to add
     160                stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
     161                // return address of thunk function as replacement expression
     162                return new AddressExpr( new VariableExpr( thunkFunc ) );
     163        }
     164
     165        Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
    91166                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
    92167                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
     
    110185        }
    111186
     187        bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
     188                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
     189                        return ftype->isTtype();
     190                }
     191                return false;
     192        }
     193
    112194        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
    113195        template< typename OutIterator >
     
    125207
    126208        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
    127         /// [begin, end) are the formal parameters.
    128         /// args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
     209        // [begin, end) are the formal parameters.
     210        // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
    129211        template< typename Iterator, typename OutIterator >
    130212        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
    131                 if ( Tuples::isTtype( last->get_result() ) ) {
    132                         *out++ = last;
    133                 } else {
    134                         // safe_dynamic_cast for the assertion
    135                         safe_dynamic_cast< TupleType * >( last->get_result() );
    136                         unsigned idx = 0;
    137                         for ( ; begin != end; ++begin ) {
    138                                 DeclarationWithType * formal = *begin;
    139                                 Type * formalType = formal->get_type();
    140                                 matchOneFormal( last, idx, formalType, out );
    141                         }
    142                         delete last;
    143                 }
    144         }
    145 
    146         /// Generates a thunk that calls `actual` with type `funType` and returns its address
    147         Expression * Specialize::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
    148                 static UniqueName thunkNamer( "_thunk" );
     213                // safe_dynamic_cast for the assertion
     214                safe_dynamic_cast< TupleType * >( last->get_result() );
     215                unsigned idx = 0;
     216                for ( ; begin != end; ++begin ) {
     217                        DeclarationWithType * formal = *begin;
     218                        Type * formalType = formal->get_type();
     219                        matchOneFormal( last, idx, formalType, out );
     220                }
     221                delete last;
     222        }
     223
     224        Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
     225                static UniqueName thunkNamer( "_tupleThunk" );
    149226
    150227                FunctionType *newType = funType->clone();
     
    176253                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
    177254
     255                Expression * last = nullptr;
    178256                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
    179257                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
     
    181259                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
    182260                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
    183                                 fixLastArg( new VariableExpr( param ), actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
     261                                last = new VariableExpr( param );
    184262                                break;
    185263                        }
     
    190268                        appExpr->get_args().push_back( new VariableExpr( param ) );
    191269                } // for
     270                assert( last );
     271                fixLastArg( last, actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
    192272                appExpr->set_env( maybeClone( env ) );
    193273                if ( inferParams ) {
     
    201281                std::list< Statement* > oldStmts;
    202282                oldStmts.splice( oldStmts.end(), stmtsToAdd );
    203                 mutate( appExpr );
     283                spec.mutate( appExpr );
    204284                paramPrefix = oldParamPrefix;
    205285                // write any statements added for recursive specializations into the thunk body
     
    231311                std::list< Expression* >::iterator actual;
    232312                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
    233                         *actual = doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
     313                        *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
    234314                }
    235315        }
     
    242322                        // create thunks for the inferred parameters
    243323                        // don't need to do this for intrinsic calls, because they aren't actually passed
    244                         // need to handle explicit params before inferred params so that explicit params do not recieve a changed set of inferParams (and change them again)
    245                         // alternatively, if order starts to matter then copy appExpr's inferParams and pass them to handleExplicitParams.
     324                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
     325                                inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
     326                        }
    246327                        handleExplicitParams( appExpr );
    247                         for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
    248                                 inferParam->second.expr = doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
    249                         }
    250328                }
    251329                return appExpr;
     
    255333                addrExpr->get_arg()->acceptMutator( *this );
    256334                assert( addrExpr->has_result() );
    257                 addrExpr->set_arg( doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
     335                addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
    258336                return addrExpr;
    259337        }
     
    265343                        return castExpr;
    266344                }
    267                 Expression *specialized = doSpecialization( castExpr->get_result(), castExpr->get_arg() );
     345                Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
    268346                if ( specialized != castExpr->get_arg() ) {
    269347                        // assume here that the specialization incorporates the cast
     
    292370        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
    293371                Specialize spec;
     372
     373                TupleSpecializer tupleSpec( spec );
     374                spec.specializer = &tupleSpec;
     375                mutateAll( translationUnit, spec );
     376
     377                PolySpecializer polySpec( spec );
     378                spec.specializer = &polySpec;
    294379                mutateAll( translationUnit, spec );
    295380        }
Note: See TracChangeset for help on using the changeset viewer.