Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/GenPoly/Specialize.cc

    r68fe077a r623ecf3  
    7777        }
    7878
     79        /// True if both types have the same structure, but not necessarily the same types.
     80        /// That is, either both types are tuple types with the same size (recursively), or
     81        /// both are not tuple types.
     82        bool matchingTupleStructure( Type * t1, Type * t2 ) {
     83                TupleType * tuple1 = dynamic_cast< TupleType * >( t1 );
     84                TupleType * tuple2 = dynamic_cast< TupleType * >( t2 );
     85                if ( tuple1 && tuple2 ) {
     86                        if ( tuple1->size() != tuple2->size() ) return false;
     87                        for ( auto types : group_iterate( tuple1->get_types(), tuple2->get_types() ) ) {
     88                                if ( ! matchingTupleStructure( std::get<0>( types ), std::get<1>( types ) ) ) return false;
     89                        }
     90                        return true;
     91                } else if ( ! tuple1 && ! tuple2 ) return true;
     92                return false;
     93        }
     94
    7995        bool needsTupleSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    80                 if ( FunctionType * ftype = getFunctionType( formalType ) ) {
    81                         return ftype->isTtype();
     96                // Needs tuple specialization if the structure of the formal type and actual type do not match.
     97                // This is the case if the formal type has ttype polymorphism, or if the structure  of tuple types
     98                // between the function do not match exactly.
     99                if ( FunctionType * fftype = getFunctionType( formalType ) ) {
     100                        if ( fftype->isTtype() ) return true;
     101                        FunctionType * aftype = getFunctionType( actualType );
     102                        assertf( aftype, "formal type is a function type, but actual type is not." );
     103                        if ( fftype->get_parameters().size() != aftype->get_parameters().size() ) return true;
     104                        for ( auto params : group_iterate( fftype->get_parameters(), aftype->get_parameters() ) ) {
     105                                DeclarationWithType * formal = std::get<0>(params);
     106                                DeclarationWithType * actual = std::get<1>(params);
     107                                if ( ! matchingTupleStructure( formal->get_type(), actual->get_type() ) ) return true;
     108                        }
    82109                }
    83110                return false;
     
    110137        }
    111138
    112         /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
    113         template< typename OutIterator >
    114         void matchOneFormal( Expression * arg, unsigned & idx, Type * formal, OutIterator out ) {
    115                 if ( TupleType * tupleType = dynamic_cast< TupleType * >( formal ) ) {
     139        /// restructures the arguments to match the structure of the formal parameters of the actual function.
     140        /// [begin, end) are the exploded arguments.
     141        template< typename Iterator, typename OutIterator >
     142        void structureArg( Type * type, Iterator & begin, Iterator end, OutIterator out ) {
     143                if ( TupleType * tuple = dynamic_cast< TupleType * >( type ) ) {
    116144                        std::list< Expression * > exprs;
    117                         for ( Type * t : *tupleType ) {
    118                                 matchOneFormal( arg, idx, t, back_inserter( exprs ) );
     145                        for ( Type * t : *tuple ) {
     146                                structureArg( t, begin, end, back_inserter( exprs ) );
    119147                        }
    120148                        *out++ = new TupleExpr( exprs );
    121149                } else {
    122                         *out++ = new TupleIndexExpr( arg->clone(), idx++ );
    123                 }
    124         }
    125 
    126         /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
    127         /// [begin, end) are the formal parameters.
    128         /// args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
    129         template< typename Iterator, typename OutIterator >
    130         void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
    131                 if ( Tuples::isTtype( last->get_result() ) ) {
    132                         *out++ = last;
    133                 } else {
    134                         // safe_dynamic_cast for the assertion
    135                         safe_dynamic_cast< TupleType * >( last->get_result() );
    136                         unsigned idx = 0;
    137                         for ( ; begin != end; ++begin ) {
    138                                 DeclarationWithType * formal = *begin;
    139                                 Type * formalType = formal->get_type();
    140                                 matchOneFormal( last, idx, formalType, out );
    141                         }
    142                         delete last;
    143                 }
     150                        assertf( begin != end, "reached the end of the arguments while structuring" );
     151                        *out++ = *begin++;
     152                }
     153        }
     154
     155        /// explode assuming simple cases: either type is pure tuple (but not tuple expr) or type is non-tuple.
     156        template< typename OutputIterator >
     157        void explodeSimple( Expression * expr, OutputIterator out ) {
     158                if ( TupleType * tupleType = dynamic_cast< TupleType * > ( expr->get_result() ) ) {
     159                        // tuple type, recursively index into its components
     160                        for ( unsigned int i = 0; i < tupleType->size(); i++ ) {
     161                                explodeSimple( new TupleIndexExpr( expr->clone(), i ), out );
     162                        }
     163                        delete expr;
     164                } else {
     165                        // non-tuple type - output a clone of the expression
     166                        *out++ = expr;
     167                }
     168        }
     169
     170        struct EnvTrimmer : public Visitor {
     171                TypeSubstitution * env, * newEnv;
     172                EnvTrimmer( TypeSubstitution * env, TypeSubstitution * newEnv ) : env( env ), newEnv( newEnv ){}
     173                virtual void visit( TypeDecl * tyDecl ) {
     174                        // transfer known bindings for seen type variables
     175                        if ( Type * t = env->lookup( tyDecl->get_name() ) ) {
     176                                newEnv->add( tyDecl->get_name(), t );
     177                        }
     178                }
     179        };
     180
     181        /// reduce environment to just the parts that are referenced in a given expression
     182        TypeSubstitution * trimEnv( ApplicationExpr * expr, TypeSubstitution * env ) {
     183                if ( env ) {
     184                        TypeSubstitution * newEnv = new TypeSubstitution();
     185                        EnvTrimmer trimmer( env, newEnv );
     186                        expr->accept( trimmer );
     187                        return newEnv;
     188                }
     189                return nullptr;
    144190        }
    145191
     
    173219                std::list< DeclarationWithType * >::iterator actualBegin = actualType->get_parameters().begin();
    174220                std::list< DeclarationWithType * >::iterator actualEnd = actualType->get_parameters().end();
    175                 std::list< DeclarationWithType * >::iterator formalBegin = funType->get_parameters().begin();
    176                 std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
    177 
     221
     222                std::list< Expression * > args;
    178223                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
    179                         // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
     224                        // name each thunk parameter and explode it - these are then threaded back into the actual function call.
    180225                        param->set_name( paramNamer.newName() );
    181                         assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
    182                         if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
    183                                 fixLastArg( new VariableExpr( param ), actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
    184                                 break;
    185                         }
    186                         assertf( actualBegin != actualEnd, "reached end of actual function's arguments before finding ttype parameter" );
    187                         ++actualBegin;
    188                         ++formalBegin;
    189 
    190                         appExpr->get_args().push_back( new VariableExpr( param ) );
    191                 } // for
    192                 appExpr->set_env( maybeClone( env ) );
     226                        explodeSimple( new VariableExpr( param ), back_inserter( args ) );
     227                }
     228
     229                // walk parameters to the actual function alongside the exploded thunk parameters and restructure the arguments to match the actual parameters.
     230                std::list< Expression * >::iterator argBegin = args.begin(), argEnd = args.end();
     231                for ( ; actualBegin != actualEnd; ++actualBegin ) {
     232                        structureArg( (*actualBegin)->get_type(), argBegin, argEnd, back_inserter( appExpr->get_args() ) );
     233                }
     234
     235                appExpr->set_env( trimEnv( appExpr, env ) );
    193236                if ( inferParams ) {
    194237                        appExpr->get_inferParams() = *inferParams;
Note: See TracChangeset for help on using the changeset viewer.