Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/GenPoly/Specialize.cc

    r623ecf3 r68fe077a  
    7777        }
    7878
    79         /// True if both types have the same structure, but not necessarily the same types.
    80         /// That is, either both types are tuple types with the same size (recursively), or
    81         /// both are not tuple types.
    82         bool matchingTupleStructure( Type * t1, Type * t2 ) {
    83                 TupleType * tuple1 = dynamic_cast< TupleType * >( t1 );
    84                 TupleType * tuple2 = dynamic_cast< TupleType * >( t2 );
    85                 if ( tuple1 && tuple2 ) {
    86                         if ( tuple1->size() != tuple2->size() ) return false;
    87                         for ( auto types : group_iterate( tuple1->get_types(), tuple2->get_types() ) ) {
    88                                 if ( ! matchingTupleStructure( std::get<0>( types ), std::get<1>( types ) ) ) return false;
    89                         }
    90                         return true;
    91                 } else if ( ! tuple1 && ! tuple2 ) return true;
    92                 return false;
    93         }
    94 
    9579        bool needsTupleSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    96                 // Needs tuple specialization if the structure of the formal type and actual type do not match.
    97                 // This is the case if the formal type has ttype polymorphism, or if the structure  of tuple types
    98                 // between the function do not match exactly.
    99                 if ( FunctionType * fftype = getFunctionType( formalType ) ) {
    100                         if ( fftype->isTtype() ) return true;
    101                         FunctionType * aftype = getFunctionType( actualType );
    102                         assertf( aftype, "formal type is a function type, but actual type is not." );
    103                         if ( fftype->get_parameters().size() != aftype->get_parameters().size() ) return true;
    104                         for ( auto params : group_iterate( fftype->get_parameters(), aftype->get_parameters() ) ) {
    105                                 DeclarationWithType * formal = std::get<0>(params);
    106                                 DeclarationWithType * actual = std::get<1>(params);
    107                                 if ( ! matchingTupleStructure( formal->get_type(), actual->get_type() ) ) return true;
    108                         }
     80                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
     81                        return ftype->isTtype();
    10982                }
    11083                return false;
     
    137110        }
    138111
    139         /// restructures the arguments to match the structure of the formal parameters of the actual function.
    140         /// [begin, end) are the exploded arguments.
    141         template< typename Iterator, typename OutIterator >
    142         void structureArg( Type * type, Iterator & begin, Iterator end, OutIterator out ) {
    143                 if ( TupleType * tuple = dynamic_cast< TupleType * >( type ) ) {
     112        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
     113        template< typename OutIterator >
     114        void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
     115                if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
    144116                        std::list< Expression * > exprs;
    145                         for ( Type * t : *tuple ) {
    146                                 structureArg( t, begin, end, back_inserter( exprs ) );
     117                        for ( Type * t : *tupleType ) {
     118                                matchOneFormal( arg, idx, t, back_inserter( exprs ) );
    147119                        }
    148120                        *out++ = new TupleExpr( exprs );
    149121                } else {
    150                         assertf( begin != end, "reached the end of the arguments while structuring" );
    151                         *out++ = *begin++;
    152                 }
    153         }
    154 
    155         /// explode assuming simple cases: either type is pure tuple (but not tuple expr) or type is non-tuple.
    156         template< typename OutputIterator >
    157         void explodeSimple( Expression * expr, OutputIterator out ) {
    158                 if ( TupleType * tupleType = dynamic_cast< TupleType * > ( expr->get_result() ) ) {
    159                         // tuple type, recursively index into its components
    160                         for ( unsigned int i = 0; i < tupleType->size(); i++ ) {
    161                                 explodeSimple( new TupleIndexExpr( expr->clone(), i ), out );
     122                        *out++ = new TupleIndexExpr( arg->clone(), idx++ );
     123                }
     124        }
     125
     126        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
     127        /// [begin, end) are the formal parameters.
     128        /// args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
     129        template< typename Iterator, typename OutIterator >
     130        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
     131                if ( Tuples::isTtype( last->get_result() ) ) {
     132                        *out++ = last;
     133                } else {
     134                        // safe_dynamic_cast for the assertion
     135                        safe_dynamic_cast< TupleType * >( last->get_result() );
     136                        unsigned idx = 0;
     137                        for ( ; begin != end; ++begin ) {
     138                                DeclarationWithType * formal = *begin;
     139                                Type * formalType = formal->get_type();
     140                                matchOneFormal( last, idx, formalType, out );
    162141                        }
    163                         delete expr;
    164                 } else {
    165                         // non-tuple type - output a clone of the expression
    166                         *out++ = expr;
    167                 }
    168         }
    169 
    170         struct EnvTrimmer : public Visitor {
    171                 TypeSubstitution * env, * newEnv;
    172                 EnvTrimmer( TypeSubstitution * env, TypeSubstitution * newEnv ) : env( env ), newEnv( newEnv ){}
    173                 virtual void visit( TypeDecl * tyDecl ) {
    174                         // transfer known bindings for seen type variables
    175                         if ( Type * t = env->lookup( tyDecl->get_name() ) ) {
    176                                 newEnv->add( tyDecl->get_name(), t );
    177                         }
    178                 }
    179         };
    180 
    181         /// reduce environment to just the parts that are referenced in a given expression
    182         TypeSubstitution * trimEnv( ApplicationExpr * expr, TypeSubstitution * env ) {
    183                 if ( env ) {
    184                         TypeSubstitution * newEnv = new TypeSubstitution();
    185                         EnvTrimmer trimmer( env, newEnv );
    186                         expr->accept( trimmer );
    187                         return newEnv;
    188                 }
    189                 return nullptr;
     142                        delete last;
     143                }
    190144        }
    191145
     
    219173                std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
    220174                std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
    221 
    222                 std::list< Expression * > args;
     175                std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
     176                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
     177
    223178                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
    224                         // name each thunk parameter and explode it - these are then threaded back into the actual function call.
     179                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
    225180                        param->set_name( paramNamer.newName() );
    226                         explodeSimple( new VariableExpr( param ), back_inserter( args ) );
    227                 }
    228 
    229                 // walk parameters to the actual function alongside the exploded thunk parameters and restructure the arguments to match the actual parameters.
    230                 std::list< Expression * >::iterator argBegin = args.begin(), argEnd = args.end();
    231                 for ( ; actualBegin != actualEnd; ++actualBegin ) {
    232                         structureArg( (*actualBegin)->get_type(), argBegin, argEnd, back_inserter( appExpr->get_args() ) );
    233                 }
    234 
    235                 appExpr->set_env( trimEnv( appExpr, env ) );
     181                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
     182                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
     183                                fixLastArg( new VariableExpr( param ), actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
     184                                break;
     185                        }
     186                        assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
     187                        ++actualBegin;
     188                        ++formalBegin;
     189
     190                        appExpr->get_args().push_back( new VariableExpr( param ) );
     191                } // for
     192                appExpr->set_env( maybeClone( env ) );
    236193                if ( inferParams ) {
    237194                        appExpr->get_inferParams() = *inferParams;
Note: See TracChangeset for help on using the changeset viewer.