Changeset b60f9d9


Ignore:
Timestamp:
Jun 13, 2018, 4:14:31 PM (6 years ago)
Author:
Aaron Moss <a3moss@…>
Branches:
new-env
Children:
97397a26
Parents:
6d53e779
Message:

Start on breadth-first assertion resolution

Location:
src/ResolvExpr
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • src/ResolvExpr/AlternativeFinder.cc

    r6d53e779 rb60f9d9  
    99// Author           : Richard C. Bilson
    1010// Created On       : Sat May 16 23:52:08 2015
    11 // Last Modified By : Peter A. Buhr
    12 // Last Modified On : Sat Feb 17 11:19:39 2018
    13 // Update Count     : 33
     11// Last Modified By : Aaron B. Moss
     12// Last Modified On : Mon Jun 11 16:40:00 2018
     13// Update Count     : 34
    1414//
    1515
     
    1919#include <iostream>                // for operator<<, cerr, ostream, endl
    2020#include <iterator>                // for back_insert_iterator, back_inserter
     21#include <limits>                  // for numeric_limits<int>::max()
    2122#include <list>                    // for _List_iterator, list, _List_const_...
    2223#include <map>                     // for _Rb_tree_iterator, map, _Rb_tree_c...
    2324#include <memory>                  // for allocator_traits<>::value_type
     25#include <tuple>                   // for tuple
    2426#include <utility>                 // for pair
    2527#include <vector>                  // for vector
     
    444446                }
    445447
     448                #if 0 // cost of assertions accounted for in function creation
    446449                for ( InferredParams::const_iterator assert = appExpr->get_inferParams().begin(); assert != appExpr->get_inferParams().end(); ++assert ) {
    447450                        convCost += computeConversionCost( assert->second.actualType, assert->second.formalType, indexer, alt.env );
    448451                }
     452                #endif
    449453
    450454                return convCost;
     
    572576        }
    573577
     578        namespace {
     579                /// Information required to defer resolution of an expression
     580                struct AssertionPack {
     581                        SymTab::Indexer::IdData cdata;  ///< Satisfying declaration
     582                        Type* adjType;                  ///< Satisfying type
     583                        TypeEnvironment env;            ///< Post-unification environment
     584                        AssertionSet have;              ///< Post-unification have-set
     585                        AssertionSet need;              ///< Post-unification need-set
     586                        OpenVarSet openVars;            ///< Post-unification open-var set
     587
     588                        AssertionPack( const SymTab::Indexer::IdData& cdata, Type* adjType,
     589                                        TypeEnvironment&& env, AssertionSet&& have, AssertionSet&& need,
     590                                        OpenVarSet&& openVars )
     591                                : cdata(cdata), adjType(adjType), env(std::move(env)), have(std::move(have)),
     592                                  need(std::move(need)), openVars(std::move(openVars)) {}
     593                };
     594
     595                /// List of deferred assertion resolutions for the same type
     596                using DeferList = std::vector<AssertionPack>;
     597
     598                /// Intermediate state for assertion resolution
     599                struct AssertionResnState {
     600                        using DeferItem = std::tuple<DeclarationWithType*, AssertionSetValue, DeferList>;
     601
     602                        const Alternative& alt;           ///< Alternative being built from
     603                        AssertionSet newNeed;             ///< New assertions found from current assertions
     604                        OpenVarSet openVars;              ///< Open variables in current context
     605                        std::vector<DeferItem> deferred;  ///< Possible deferred assertion resolutions
     606                        const SymTab::Indexer& indexer;   ///< Name lookup
     607
     608                        AssertionResnState(const Alternative& alt, const OpenVarSet& openVars,
     609                                const SymTab::Indexer& indexer )
     610                                : alt{alt}, have{}, newNeed{}, openVars{openVars}, indexer{indexer} {}
     611                };
     612
     613                /// Binds a single assertion from a compatible AssertionPack, updating resolution state
     614                /// as appropriate.
     615                void bindAssertion( DeclarationWithType* curDecl, AssertionSetValue& assnInfo,
     616                                AssertionResnState& resn, AssertionPack&& match ) {
     617                        DeclarationWithType* candidate = match.cdata.id;
     618                       
     619                        addToIndexer( match.have, resn.indexer );
     620                        resn.newNeed.insert( match.need.begin(), match.need.end() );
     621                        resn.openVars = std::move(match.openVars);
     622                        resn.alt.env = std::move(match.env);
     623
     624                        assertf( candidate->get_uniqueId(), "Assertion candidate does not have a unique ID: %s", toString( candidate ).c_str() );
     625                        for ( auto& a : match.need ) {
     626                                if ( a.second.idChain.empty() ) {
     627                                        a.second.idChain = assnInfo.idChain;
     628                                        a.second.idChain.push_back( curDecl->get_uniqueId() );
     629                                }
     630                        }
     631
     632                        Expression* varExpr = match.cdata.combine( resn.alt.cvtCost );
     633                        varExpr->result = match.adjType;
     634
     635                        // follow the current assertion's ID chain to find the correct set of inferred
     636                        // parameters to add the candidate o (i.e. the set of inferred parameters belonging
     637                        // to the entity which requested the assertion parameter)
     638                        InferredParams* inferParams = &resn.alt.expr->inferParams;
     639                        for ( UniqueId id : assnInfo.idChain ) {
     640                                inferParams = (*inferParams)[ id ].inferParams.get();
     641                        }
     642
     643                        (*inferParams)[ curDecl->get_uniqueId() ] = ParamEntry{
     644                                candidate->get_uniqueId(), match.adjType, curDecl->get_type(), varExpr };
     645                }
     646
     647                /// Resolves a single assertion, returning false if no satisfying assertion, binding it
     648                /// if there is exactly one satisfying assertion, or adding to the defer-list if there
     649                /// is more than one
     650                bool resolveAssertion( DeclarationWithType* curDecl, AssertionSetValue& assnInfo,
     651                                AssertionResnState& resn ) {
     652                        // skip unused assertions
     653                        if ( ! assnInfo.isUsed ) return true;
     654
     655                        // lookup candidates for this assertion
     656                        std::list< SymTab::Indexer::IdData > candidates;
     657                        decls.lookupId( curDecl->name, candidates );
     658
     659                        // find the ones that unify with the desired type
     660                        DeferList matches;
     661                        for ( const auto& cdata : candidates ) {
     662                                DeclarationWithType* candidate = cdata.id;
     663
     664                                // build independent unification context for candidate
     665                                AssertionSet have, newNeed;
     666                                TypeEnvironment newEnv{ resn.alt.env };
     667                                OpenVarSet newOpenVars{ resn.openVars };
     668                                Type* adjType = candidate->get_type()->clone();
     669                                adjustExprType( adjType, newEnv, resn.indexer );
     670                                renameTyVars( adjType );
     671
     672                                if ( unify( curDecl->get_type(), adjType, newEnv,
     673                                                newNeed, have, newOpenVars, resn.indexer ) ) {
     674                                        matches.emplace_back( cdata, adjType, std::move(newEnv), std::move(have),
     675                                                std::move(newNeed), std::move(newOpenVars) );
     676                                }
     677                        }
     678
     679                        // Break if no suitable assertion
     680                        if ( matches.empty() ) return false;
     681
     682                        // Defer if too many suitable assertions
     683                        if ( matches.size() > 1 ) {
     684                                resn.deferred.emplace_back( curDecl, assnInfo, std::move(matches) );
     685                                return true;
     686                        }
     687
     688                        // otherwise bind current match in ongoing scope
     689                        bindAssertion( curDecl, assnInfo, resn, std::move(matches.front()) );
     690
     691                        return true;
     692                }
     693        }
     694
    574695        template< typename OutputIterator >
    575696        void AlternativeFinder::Finder::inferParameters( const AssertionSet &need, AssertionSet &have, const Alternative &newAlt, OpenVarSet &openVars, OutputIterator out ) {
     
    586707                // )
    587708                addToIndexer( have, decls );
     709
     710                AssertionResnState resn{ newAlt, openVars, indexer };
     711
     712                // resolve assertions in breadth-first-order up to a limited number of levels deep
     713                int level;
     714                for ( level = 0; level < recursionLimit; ++level ) {
     715                        // make initial pass at matching assertions
     716                        for ( auto& assn : need ) {
     717                                if ( ! resolveAssertion( assn.first, assn.second, resn ) ) {
     718                                        // fail early if any assertion fails to resolve
     719                                        return;
     720                                }
     721                        }
     722
     723                        // resolve deferred assertions by mutual compatibility and min-cost
     724                        if ( ! resn.deferred.empty() ) {
     725                                // TODO
     726                                assert(false && "TODO: deferred assertions unimplemented");
     727
     728                                // reset for next round
     729                                resn.deferred.clear();
     730                        }
     731
     732                        // quit resolving assertions if done
     733                        if ( resn.newNeed.empty() ) break;
     734
     735                        // otherwise start on next group of recursive assertions
     736                        need.swap( resn.newNeed );
     737                        resn.newNeed.clear();
     738                }
     739                if ( level >= recursionLimit ) {
     740                        SemanticError( newAlt.expr->location, "Too many recursive assertions" );
     741                }
     742               
     743                // add completed assertion to output
     744                *out++ = newAlt;
     745
     746#if 0           
    588747                AssertionSet newNeed;
    589748                PRINT(
     
    599758//          *out++ = newAlt;
    600759//          )
     760#endif
    601761        }
    602762
     
    630790
    631791                ArgPack(const TypeEnvironment& env, const AssertionSet& need, const AssertionSet& have,
    632                                 const OpenVarSet& openVars)
    633                         : parent(0), expr(nullptr), cost(Cost::zero), env(env), need(need), have(have),
     792                                const OpenVarSet& openVars, Cost initCost = Cost::zero)
     793                        : parent(0), expr(nullptr), cost(initCost), env(env), need(need), have(have),
    634794                          openVars(openVars), nextArg(0), tupleStart(0), nextExpl(0), explAlt(0) {}
    635795
     
    9231083        }
    9241084
     1085        namespace {
     1086
     1087                struct CountSpecs : public WithVisitorRef<CountSpecs>, WithShortCircuiting {
     1088
     1089                        void postvisit(PointerType*) {
     1090                                // mark specialization of base type
     1091                                if ( count >= 0 ) ++count;
     1092                        }
     1093
     1094                        void postvisit(ArrayType*) {
     1095                                // mark specialization of base type
     1096                                if ( count >= 0 ) ++count;
     1097                        }
     1098
     1099                        void postvisit(ReferenceType*) {
     1100                                // mark specialization of base type -- xxx maybe not?
     1101                                if ( count >= 0 ) ++count;
     1102                        }
     1103
     1104                private:
     1105                        // takes minimum non-negative count over parameter/return list
     1106                        void takeminover( int& mincount, std::list<DeclarationWithType*>& dwts ) {
     1107                                for ( DeclarationWithType* dwt : dwts ) {
     1108                                        count = -1;
     1109                                        maybeAccept( dwt->get_type(), *visitor );
     1110                                        if ( count != -1 && count < mincount ) mincount = count;
     1111                                }
     1112                        }
     1113
     1114                public:
     1115                        void previsit(FunctionType*) {
     1116                                // override default child visiting behaviour
     1117                                visit_children = false;
     1118                        }
     1119
     1120                        void postvisit(FunctionType* fty) {
     1121                                // take minimal set value of count over ->returnVals and ->parameters
     1122                                int mincount = std::numeric_limits<int>::max();
     1123                                takeminover( mincount, fty->parameters );
     1124                                takeminover( mincount, fty->returnVals );
     1125                                // add another level to mincount if set
     1126                                count = mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
     1127                        }
     1128
     1129                private:
     1130                        // returns minimum non-negative count + 1 over type parameters (-1 if none such)
     1131                        int minover( std::list<Expression*>& parms ) {
     1132                                int mincount = std::numeric_limits<int>::max();
     1133                                for ( Expression* parm : parms ) {
     1134                                        count = -1;
     1135                                        maybeAccept( parm->get_result(), *visitor );
     1136                                        if ( count != -1 && count < mincount ) mincount = count;
     1137                                }
     1138                                return mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
     1139                        }
     1140                       
     1141                public:
     1142                        void previsit(StructInstType*) {
     1143                                // override default child behaviour
     1144                                visit_children = false;
     1145                        }
     1146
     1147                        void postvisit(StructInstType* sity) {
     1148                                // look for polymorphic parameters
     1149                                count = minover( sity->parameters );
     1150                        }
     1151
     1152                        void previsit(UnionInstType*) {
     1153                                // override default child behaviour
     1154                                visit_children = false;
     1155                        }
     1156
     1157                        void postvisit(UnionInstType* uity) {
     1158                                // look for polymorphic parameters
     1159                                count = minover( uity->parameters );
     1160                        }
     1161
     1162                        void postvisit(TypeInstType*) {
     1163                                // note polymorphic type (which may be specialized)
     1164                                // xxx - maybe account for open/closed type variables
     1165                                count = 0;
     1166                        }
     1167
     1168                        void previsit(TupleType*) {
     1169                                // override default child behaviour
     1170                                visit_children = false;
     1171                        }
     1172
     1173                        void postvisit(TupleType* tty) {
     1174                                // take minimum non-negative count
     1175                                int mincount = std::numeric_limits<int>::max();
     1176                                for ( Type* ty : tty->types ) {
     1177                                        count = -1;
     1178                                        maybeAccept( ty, *visitor );
     1179                                        if ( count != -1 && count < mincount ) mincount = count;
     1180                                }
     1181                                // xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
     1182                                count = mincount < std::numeric_limits<int>::max() ? mincount + 1: -1;
     1183                        }
     1184
     1185                        int get_count() const { return count >= 0 ? count : 0; }
     1186                private:
     1187                        int count = -1;
     1188                };
     1189
     1190                /// Counts the specializations in the types in a function parameter or return list
     1191                int countSpecs( std::list<DeclarationWithType*>& dwts ) {
     1192                        int k = 0;
     1193                        for ( DeclarationWithType* dwt : dwts ) {
     1194                                PassVisitor<CountSpecs> counter;
     1195                                maybeAccept( dwt->get_type(), *counter.pass.visitor );
     1196                                k += counter.pass.get_count();
     1197                        }
     1198                        return k;
     1199                }
     1200
     1201                /// Calculates the inherent costs in a function declaration; varCost for the number of
     1202                /// type variables and specCost for type assertions, as well as PolyType specializations
     1203                /// in the parameter and return lists.
     1204                Cost declCost( FunctionType* funcType ) {
     1205                        Cost k = Cost::zero;
     1206
     1207                        // add cost of type variables
     1208                        k.incVar( funcType->forall.size() );
     1209
     1210                        // subtract cost of type assertions
     1211                        for ( TypeDecl* td : funcType->forall ) {
     1212                                k.decSpec( td->assertions.size() );
     1213                        }
     1214
     1215                        // count specialized polymorphic types in parameter/return lists
     1216                        k.decSpec( countSpecs( funcType->parameters ) );
     1217                        k.decSpec( countSpecs( funcType->returnVals ) );
     1218
     1219                        return k;
     1220                }
     1221        }
     1222
    9251223        template<typename OutputIterator>
    9261224        void AlternativeFinder::Finder::validateFunctionAlternative( const Alternative &func,
     
    9671265                }
    9681266
     1267                // calculate declaration cost of function (+vars-spec)
     1268                Cost funcCost = declCost( funcType );
     1269
    9691270                // iteratively build matches, one parameter at a time
    9701271                std::vector<ArgPack> results;
    971                 results.push_back( ArgPack{ funcEnv, funcNeed, funcHave, funcOpenVars } );
     1272                results.push_back( ArgPack{ funcEnv, funcNeed, funcHave, funcOpenVars, funcCost } );
    9721273                std::size_t genStart = 0;
    9731274
  • src/ResolvExpr/Resolver.cc

    r6d53e779 rb60f9d9  
    5151namespace ResolvExpr {
    5252        struct Resolver final : public WithIndexer, public WithGuards, public WithVisitorRef<Resolver>, public WithShortCircuiting, public WithStmtsToAdd {
     53       
     54        friend void resolve( std::list<Declaration*> );
     55
    5356                Resolver() {}
    5457                Resolver( const SymTab::Indexer & other ) {
     
    9699                CurrentObject currentObject = nullptr;
    97100                bool inEnumDecl = false;
     101                bool atTopLevel = false;  ///< Was this resolver set up at the top level of resolution
    98102        };
    99103
    100104        void resolve( std::list< Declaration * > translationUnit ) {
    101105                PassVisitor<Resolver> resolver;
     106                resolver.pass.atTopLevel = true;  // mark resolver as top-level
    102107                acceptAll( translationUnit, resolver );
    103108        }
Note: See TracChangeset for help on using the changeset viewer.