Changeset f3b0a07


Ignore:
Timestamp:
Jan 16, 2017, 3:29:18 PM (5 years ago)
Author:
Rob Schluntz <rschlunt@…>
Branches:
aaron-thesis, arm-eh, cleanup-dtors, deferred_resn, demangler, jacob/cs343-translation, jenkins-sandbox, master, new-ast, new-ast-unique-expr, new-env, no_list, persistent-indexer, resolv-new, with_gc
Children:
5ebb2fbc
Parents:
981bdc6
Message:

allow ttypes contained in tuple types to unify, refactor and simplify Specialize pass

Location:
src
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • src/GenPoly/Specialize.cc

    r981bdc6 rf3b0a07  
    3535
    3636namespace GenPoly {
    37         class Specializer;
    3837        class Specialize final : public PolyMutator {
    39                 friend class Specializer;
    4038          public:
    4139                using PolyMutator::mutate;
     
    4745                // virtual Expression * mutate( CommaExpr *commaExpr );
    4846
    49                 Specializer * specializer = nullptr;
    5047                void handleExplicitParams( ApplicationExpr *appExpr );
     48                Expression * createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams );
     49                Expression * doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = nullptr );
     50
     51                std::string paramPrefix = "_p";
    5152        };
    5253
    53         class Specializer {
    54           public:
    55                 Specializer( Specialize & spec ) : spec( spec ), env( spec.env ), stmtsToAdd( spec.stmtsToAdd ) {}
    56                 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) = 0;
    57                 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) = 0;
    58                 virtual Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
    59 
    60           protected:
    61                 Specialize & spec;
    62                 std::string paramPrefix = "_p";
    63                 TypeSubstitution *& env;
    64                 std::list< Statement * > & stmtsToAdd;
    65         };
    66 
    67         // for normal polymorphic -> monomorphic function conversion
    68         class PolySpecializer : public Specializer {
    69           public:
    70                 PolySpecializer( Specialize & spec ) : Specializer( spec ) {}
    71                 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
    72                 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
    73         };
    74 
    75         // // for tuple -> non-tuple function conversion
    76         class TupleSpecializer : public Specializer {
    77           public:
    78                 TupleSpecializer( Specialize & spec ) : Specializer( spec ) {}
    79                 virtual bool needsSpecialization( Type * formalType, Type * actualType, TypeSubstitution * env ) override;
    80                 virtual Expression *createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) override;
    81         };
    82 
    8354        /// Looks up open variables in actual type, returning true if any of them are bound in the environment or formal type.
    84         bool PolySpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
     55        bool needsPolySpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    8556                if ( env ) {
    8657                        using namespace ResolvExpr;
     
    10677        }
    10778
    108         /// Generates a thunk that calls `actual` with type `funType` and returns its address
    109         Expression * PolySpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
    110                 static UniqueName thunkNamer( "_thunk" );
    111 
    112                 FunctionType *newType = funType->clone();
    113                 if ( env ) {
    114                         // it is important to replace only occurrences of type variables that occur free in the
    115                         // thunk's type
    116                         env->applyFree( newType );
    117                 } // if
    118                 // create new thunk with same signature as formal type (C linkage, empty body)
    119                 FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false, false );
    120                 thunkFunc->fixUniqueId();
    121 
    122                 // thunks may be generated and not used - silence warning with attribute
    123                 thunkFunc->get_attributes().push_back( new Attribute( "unused" ) );
    124 
    125                 // thread thunk parameters into call to actual function, naming thunk parameters as we go
    126                 UniqueName paramNamer( paramPrefix );
    127                 ApplicationExpr *appExpr = new ApplicationExpr( actual );
    128                 for ( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
    129                         (*param )->set_name( paramNamer.newName() );
    130                         appExpr->get_args().push_back( new VariableExpr( *param ) );
    131                 } // for
    132                 appExpr->set_env( maybeClone( env ) );
    133                 if ( inferParams ) {
    134                         appExpr->get_inferParams() = *inferParams;
    135                 } // if
    136 
    137                 // handle any specializations that may still be present
    138                 std::string oldParamPrefix = paramPrefix;
    139                 paramPrefix += "p";
    140                 // save stmtsToAdd in oldStmts
    141                 std::list< Statement* > oldStmts;
    142                 oldStmts.splice( oldStmts.end(), stmtsToAdd );
    143                 spec.handleExplicitParams( appExpr );
    144                 paramPrefix = oldParamPrefix;
    145                 // write any statements added for recursive specializations into the thunk body
    146                 thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
    147                 // restore oldStmts into stmtsToAdd
    148                 stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
    149 
    150                 // add return (or valueless expression) to the thunk
    151                 Statement *appStmt;
    152                 if ( funType->get_returnVals().empty() ) {
    153                         appStmt = new ExprStmt( noLabels, appExpr );
    154                 } else {
    155                         appStmt = new ReturnStmt( noLabels, appExpr );
    156                 } // if
    157                 thunkFunc->get_statements()->get_kids().push_back( appStmt );
    158 
    159                 // add thunk definition to queue of statements to add
    160                 stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
    161                 // return address of thunk function as replacement expression
    162                 return new AddressExpr( new VariableExpr( thunkFunc ) );
    163         }
    164 
    165         Expression * Specializer::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
     79        bool needsTupleSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
     80                if ( FunctionType * ftype = getFunctionType( formalType ) ) {
     81                        return ftype->isTtype();
     82                }
     83                return false;
     84        }
     85
     86        bool needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
     87                return needsPolySpecialization( formalType, actualType, env ) || needsTupleSpecialization( formalType, actualType, env );
     88        }
     89
     90        Expression * Specialize::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams ) {
    16691                assertf( actual->has_result(), "attempting to specialize an untyped expression" );
    16792                if ( needsSpecialization( formalType, actual->get_result(), env ) ) {
     
    185110        }
    186111
    187         bool TupleSpecializer::needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env ) {
    188                 if ( FunctionType * ftype = getFunctionType( formalType ) ) {
    189                         return ftype->isTtype();
    190                 }
    191                 return false;
    192         }
    193 
    194112        /// restructures arg to match the structure of a single formal parameter. Assumes that atomic types are compatible (as the Resolver should have ensured this)
    195113        template< typename OutIterator >
     
    207125
    208126        /// restructures the ttype argument to match the structure of the formal parameters of the actual function.
    209         // [begin, end) are the formal parameters.
    210         // args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
     127        /// [begin, end) are the formal parameters.
     128        /// args is the list of arguments currently given to the actual function, the last of which needs to be restructured.
    211129        template< typename Iterator, typename OutIterator >
    212130        void fixLastArg( Expression * last, Iterator begin, Iterator end, OutIterator out ) {
    213                 // safe_dynamic_cast for the assertion
    214                 safe_dynamic_cast< TupleType * >( last->get_result() );
    215                 unsigned idx = 0;
    216                 for ( ; begin != end; ++begin ) {
    217                         DeclarationWithType * formal = *begin;
    218                         Type * formalType = formal->get_type();
    219                         matchOneFormal( last, idx, formalType, out );
    220                 }
    221                 delete last;
    222         }
    223 
    224         Expression * TupleSpecializer::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
    225                 static UniqueName thunkNamer( "_tupleThunk" );
     131                if ( Tuples::isTtype( last->get_result() ) ) {
     132                        *out++ = last;
     133                } else {
     134                        // safe_dynamic_cast for the assertion
     135                        safe_dynamic_cast< TupleType * >( last->get_result() );
     136                        unsigned idx = 0;
     137                        for ( ; begin != end; ++begin ) {
     138                                DeclarationWithType * formal = *begin;
     139                                Type * formalType = formal->get_type();
     140                                matchOneFormal( last, idx, formalType, out );
     141                        }
     142                        delete last;
     143                }
     144        }
     145
     146        /// Generates a thunk that calls `actual` with type `funType` and returns its address
     147        Expression * Specialize::createThunkFunction( FunctionType *funType, Expression *actual, InferredParams *inferParams ) {
     148                static UniqueName thunkNamer( "_thunk" );
    226149
    227150                FunctionType *newType = funType->clone();
     
    253176                std::list< DeclarationWithType * >::iterator formalEnd = funType->get_parameters().end();
    254177
    255                 Expression * last = nullptr;
    256178                for ( DeclarationWithType* param : thunkFunc->get_functionType()->get_parameters() ) {
    257179                        // walk the parameters to the actual function alongside the parameters to the thunk to find the location where the ttype parameter begins to satisfy parameters in the actual function.
     
    259181                        assertf( formalBegin != formalEnd, "Reached end of formal parameters before finding ttype parameter" );
    260182                        if ( Tuples::isTtype((*formalBegin)->get_type()) ) {
    261                                 last = new VariableExpr( param );
     183                                fixLastArg( new VariableExpr( param ), actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
    262184                                break;
    263185                        }
     
    268190                        appExpr->get_args().push_back( new VariableExpr( param ) );
    269191                } // for
    270                 assert( last );
    271                 fixLastArg( last, actualBegin, actualEnd, back_inserter( appExpr->get_args() ) );
    272192                appExpr->set_env( maybeClone( env ) );
    273193                if ( inferParams ) {
     
    281201                std::list< Statement* > oldStmts;
    282202                oldStmts.splice( oldStmts.end(), stmtsToAdd );
    283                 spec.mutate( appExpr );
     203                mutate( appExpr );
    284204                paramPrefix = oldParamPrefix;
    285205                // write any statements added for recursive specializations into the thunk body
     
    311231                std::list< Expression* >::iterator actual;
    312232                for ( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
    313                         *actual = specializer->doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
     233                        *actual = doSpecialization( (*formal )->get_type(), *actual, &appExpr->get_inferParams() );
    314234                }
    315235        }
     
    322242                        // create thunks for the inferred parameters
    323243                        // don't need to do this for intrinsic calls, because they aren't actually passed
     244                        // need to handle explicit params before inferred params so that explicit params do not recieve a changed set of inferParams (and change them again)
     245                        // alternatively, if order starts to matter then copy appExpr's inferParams and pass them to handleExplicitParams.
     246                        handleExplicitParams( appExpr );
    324247                        for ( InferredParams::iterator inferParam = appExpr->get_inferParams().begin(); inferParam != appExpr->get_inferParams().end(); ++inferParam ) {
    325                                 inferParam->second.expr = specializer->doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
     248                                inferParam->second.expr = doSpecialization( inferParam->second.formalType, inferParam->second.expr, inferParam->second.inferParams.get() );
    326249                        }
    327                         handleExplicitParams( appExpr );
    328250                }
    329251                return appExpr;
     
    333255                addrExpr->get_arg()->acceptMutator( *this );
    334256                assert( addrExpr->has_result() );
    335                 addrExpr->set_arg( specializer->doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
     257                addrExpr->set_arg( doSpecialization( addrExpr->get_result(), addrExpr->get_arg() ) );
    336258                return addrExpr;
    337259        }
     
    343265                        return castExpr;
    344266                }
    345                 Expression *specialized = specializer->doSpecialization( castExpr->get_result(), castExpr->get_arg() );
     267                Expression *specialized = doSpecialization( castExpr->get_result(), castExpr->get_arg() );
    346268                if ( specialized != castExpr->get_arg() ) {
    347269                        // assume here that the specialization incorporates the cast
     
    370292        void convertSpecializations( std::list< Declaration* >& translationUnit ) {
    371293                Specialize spec;
    372 
    373                 TupleSpecializer tupleSpec( spec );
    374                 spec.specializer = &tupleSpec;
    375                 mutateAll( translationUnit, spec );
    376 
    377                 PolySpecializer polySpec( spec );
    378                 spec.specializer = &polySpec;
    379294                mutateAll( translationUnit, spec );
    380295        }
  • src/InitTweak/InitTweak.cc

    r981bdc6 rf3b0a07  
    327327                        } else if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * > ( expr ) ) {
    328328                                return handleDerefCalledFunction( appExpr );
     329                        } else if ( AddressExpr * addrExpr = dynamic_cast< AddressExpr * >( expr ) ) {
     330                                return getCalledFunction( addrExpr->get_arg() );
    329331                        }
    330332                        return nullptr;
     
    336338                if ( ! appExpr ) return NULL;
    337339                DeclarationWithType * function = getCalledFunction( appExpr->get_function() );
    338                 assert( function );
     340                assertf( function, "getCalledFunction returned nullptr: %s", toString( appExpr->get_function() ).c_str() );
    339341                // check for Intrinsic only - don't want to remove all overridable ctor/dtors because autogenerated ctor/dtor
    340342                // will call all member dtors, and some members may have a user defined dtor.
     
    386388                } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( callExpr ) ) {
    387389                        return callArg( untypedExpr, pos );
     390                } else if ( TupleAssignExpr * tupleExpr = dynamic_cast< TupleAssignExpr * > ( callExpr ) ) {
     391                        std::list< Statement * > & stmts = tupleExpr->get_stmtExpr()->get_statements()->get_kids();
     392                        assertf( ! stmts.empty(), "TupleAssignExpr somehow has no statements." );
     393                        ExprStmt * stmt = safe_dynamic_cast< ExprStmt * >( stmts.back() );
     394                        TupleExpr * tuple = safe_dynamic_cast< TupleExpr * >( stmt->get_expr() );
     395                        assertf( ! tuple->get_exprs().empty(), "TupleAssignExpr somehow has empty tuple expr." );
     396                        return getCallArg( tuple->get_exprs().front(), pos );
    388397                } else {
    389                         assertf( false, "Unexpected expression type passed to getCallArg" );
     398                        assertf( false, "Unexpected expression type passed to getCallArg: %s", toString( callExpr ).c_str() );
    390399                }
    391400        }
  • src/ResolvExpr/Unify.cc

    r981bdc6 rf3b0a07  
    163163                        case TypeDecl::Ttype:
    164164                        // ttype unifies with any tuple type
    165                         return dynamic_cast< TupleType * >( type );
     165                        return dynamic_cast< TupleType * >( type ) || Tuples::isTtype( type );
    166166                } // switch
    167167                return false;
     
    488488        }
    489489
    490         template< typename Iterator >
    491         std::unique_ptr<Type> combineTypes( Iterator begin, Iterator end ) {
     490        template< typename Iterator, typename Func >
     491        std::unique_ptr<Type> combineTypes( Iterator begin, Iterator end, Func & toType ) {
    492492                std::list< Type * > types;
    493493                for ( ; begin != end; ++begin ) {
    494494                        // it's guaranteed that a ttype variable will be bound to a flat tuple, so ensure that this results in a flat tuple
    495                         flatten( (*begin)->get_type(), back_inserter( types ) );
     495                        flatten( toType( *begin ), back_inserter( types ) );
    496496                }
    497497                return std::unique_ptr<Type>( new TupleType( Type::Qualifiers(), types ) );
     
    500500        template< typename Iterator1, typename Iterator2 >
    501501        bool unifyDeclList( Iterator1 list1Begin, Iterator1 list1End, Iterator2 list2Begin, Iterator2 list2End, TypeEnvironment &env, AssertionSet &needAssertions, AssertionSet &haveAssertions, const OpenVarSet &openVars, const SymTab::Indexer &indexer ) {
     502                auto get_type = [](DeclarationWithType * dwt){ return dwt->get_type(); };
    502503                for ( ; list1Begin != list1End && list2Begin != list2End; ++list1Begin, ++list2Begin ) {
    503504                        Type * t1 = (*list1Begin)->get_type();
     
    509510                        if ( isTtype1 && ! isTtype2 ) {
    510511                                // combine all of the things in list2, then unify
    511                                 return unifyExact( t1, combineTypes( list2Begin, list2End ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     512                                return unifyExact( t1, combineTypes( list2Begin, list2End, get_type ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
    512513                        } else if ( isTtype2 && ! isTtype1 ) {
    513514                                // combine all of the things in list1, then unify
    514                                 return unifyExact( combineTypes( list1Begin, list1End ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     515                                return unifyExact( combineTypes( list1Begin, list1End, get_type ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
    515516                        } else if ( ! unifyExact( t1, t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer ) ) {
    516517                                return false;
     
    522523                        Type * t1 = (*list1Begin)->get_type();
    523524                        if ( Tuples::isTtype( t1 ) ) {
    524                                 return unifyExact( t1, combineTypes( list2Begin, list2End ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     525                                return unifyExact( t1, combineTypes( list2Begin, list2End, get_type ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
    525526                        } else return false;
    526527                } else if ( list2Begin != list2End ) {
     
    528529                        Type * t2 = (*list2Begin)->get_type();
    529530                        if ( Tuples::isTtype( t2 ) ) {
    530                                 return unifyExact( combineTypes( list1Begin, list1End ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     531                                return unifyExact( combineTypes( list1Begin, list1End, get_type ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
    531532                        } else return false;
    532533                } else {
     
    665666        template< typename Iterator1, typename Iterator2 >
    666667        bool unifyList( Iterator1 list1Begin, Iterator1 list1End, Iterator2 list2Begin, Iterator2 list2End, TypeEnvironment &env, AssertionSet &needAssertions, AssertionSet &haveAssertions, const OpenVarSet &openVars, WidenMode widenMode, const SymTab::Indexer &indexer ) {
     668                auto get_type = [](Type * t) { return t; };
    667669                for ( ; list1Begin != list1End && list2Begin != list2End; ++list1Begin, ++list2Begin ) {
    668                         Type *commonType = 0;
    669                         if ( ! unifyInexact( *list1Begin, *list2Begin, env, needAssertions, haveAssertions, openVars, widenMode, indexer, commonType ) ) {
     670                        Type * t1 = *list1Begin;
     671                        Type * t2 = *list2Begin;
     672                        bool isTtype1 = Tuples::isTtype( t1 );
     673                        bool isTtype2 = Tuples::isTtype( t2 );
     674                        // xxx - assumes ttype must be last parameter
     675                        // xxx - there may be a nice way to refactor this, but be careful because the argument positioning might matter in some cases.
     676                        if ( isTtype1 && ! isTtype2 ) {
     677                                // combine all of the things in list2, then unify
     678                                return unifyExact( t1, combineTypes( list2Begin, list2End, get_type ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     679                        } else if ( isTtype2 && ! isTtype1 ) {
     680                                // combine all of the things in list1, then unify
     681                                return unifyExact( combineTypes( list1Begin, list1End, get_type ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     682                        } else if ( ! unifyExact( t1, t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer ) ) {
    670683                                return false;
    671                         }
    672                         delete commonType;
     684                        } // if
     685
    673686                } // for
    674                 if ( list1Begin != list1End || list2Begin != list2End ) {
    675                         return false;
     687                if ( list1Begin != list1End ) {
     688                        // try unifying empty tuple type with ttype
     689                        Type * t1 = *list1Begin;
     690                        if ( Tuples::isTtype( t1 ) ) {
     691                                return unifyExact( t1, combineTypes( list2Begin, list2End, get_type ).get(), env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     692                        } else return false;
     693                } else if ( list2Begin != list2End ) {
     694                        // try unifying empty tuple type with ttype
     695                        Type * t2 = *list2Begin;
     696                        if ( Tuples::isTtype( t2 ) ) {
     697                                return unifyExact( combineTypes( list1Begin, list1End, get_type ).get(), t2, env, needAssertions, haveAssertions, openVars, WidenMode( false, false ), indexer );
     698                        } else return false;
    676699                } else {
    677700                        return true;
    678                 } //if
     701                } // if
    679702        }
    680703
    681704        void Unify::visit(TupleType *tupleType) {
    682705                if ( TupleType *otherTuple = dynamic_cast< TupleType* >( type2 ) ) {
    683                         result = unifyList( tupleType->get_types().begin(), tupleType->get_types().end(), otherTuple->get_types().begin(), otherTuple->get_types().end(), env, needAssertions, haveAssertions, openVars, widenMode, indexer );
     706                        std::unique_ptr<TupleType> flat1( tupleType->clone() );
     707                        std::unique_ptr<TupleType> flat2( otherTuple->clone() );
     708                        std::list<Type *> types1, types2;
     709
     710                        TtypeExpander expander( env );
     711                        flat1->acceptMutator( expander );
     712                        flat2->acceptMutator( expander );
     713
     714                        flatten( flat1.get(), back_inserter( types1 ) );
     715                        flatten( flat2.get(), back_inserter( types2 ) );
     716
     717                        result = unifyList( types1.begin(), types1.end(), types2.begin(), types2.end(), env, needAssertions, haveAssertions, openVars, widenMode, indexer );
    684718                } // if
    685719        }
Note: See TracChangeset for help on using the changeset viewer.