Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/ResolvExpr/SatisfyAssertions.cpp

    rb69233ac r396037d  
    1616#include "SatisfyAssertions.hpp"
    1717
    18 #include <algorithm>
    1918#include <cassert>
    20 #include <sstream>
    21 #include <string>
    22 #include <unordered_map>
    23 #include <vector>
    24 
    25 #include "Candidate.hpp"
    26 #include "CandidateFinder.hpp"
    27 #include "Cost.h"
    28 #include "RenameVars.h"
    29 #include "typeops.h"
    30 #include "Unify.h"
    31 #include "AST/Decl.hpp"
    32 #include "AST/Expr.hpp"
    33 #include "AST/Node.hpp"
    34 #include "AST/Pass.hpp"
    35 #include "AST/Print.hpp"
    36 #include "AST/SymbolTable.hpp"
    37 #include "AST/TypeEnvironment.hpp"
    38 #include "Common/FilterCombos.h"
    39 #include "Common/Indenter.h"
    40 #include "GenPoly/GenPoly.h"
    41 #include "SymTab/Mangler.h"
    4219
    4320namespace ResolvExpr {
    4421
    45 // in CandidateFinder.cpp; unique ID for assertion satisfaction
    46 extern UniqueId globalResnSlot;
    47 
    48 namespace {
    49         /// Post-unification assertion satisfaction candidate
    50         struct AssnCandidate {
    51                 ast::SymbolTable::IdData cdata;  ///< Satisfying declaration
    52                 ast::ptr< ast::Type > adjType;   ///< Satisfying type
    53                 ast::TypeEnvironment env;        ///< Post-unification environment
    54                 ast::AssertionSet have;          ///< Post-unification have-set
    55                 ast::AssertionSet need;          ///< Post-unification need-set
    56                 ast::OpenVarSet open;            ///< Post-unification open-var-set
    57                 ast::UniqueId resnSlot;          ///< Slot for any recursive assertion IDs
    58 
    59                 AssnCandidate(
    60                         const ast::SymbolTable::IdData c, const ast::Type * at, ast::TypeEnvironment && e,
    61                         ast::AssertionSet && h, ast::AssertionSet && n, ast::OpenVarSet && o, ast::UniqueId rs )
    62                 : cdata( c ), adjType( at ), env( std::move( e ) ), have( std::move( h ) ),
    63                   need( std::move( n ) ), open( std::move( o ) ), resnSlot( rs ) {}
    64         };
    65 
    66         /// List of assertion satisfaction candidates
    67         using AssnCandidateList = std::vector< AssnCandidate >;
    68 
    69         /// Reference to a single deferred item
    70         struct DeferRef {
    71                 const ast::DeclWithType * decl;
    72                 const ast::AssertionSetValue & info;
    73                 const AssnCandidate & match;
    74         };
    75        
    76         /// Wrapper for the deferred items from a single assertion satisfaction.
    77         /// Acts like an indexed list of DeferRef
    78         struct DeferItem {
    79                 const ast::DeclWithType * decl;
    80                 const ast::AssertionSetValue & info;
    81                 AssnCandidateList matches;
    82 
    83                 DeferItem(
    84                         const ast::DeclWithType * d, const ast::AssertionSetValue & i, AssnCandidateList && ms )
    85                 : decl( d ), info( i ), matches( std::move( ms ) ) {}
    86 
    87                 bool empty() const { return matches.empty(); }
    88 
    89                 AssnCandidateList::size_type size() const { return matches.size(); }
    90 
    91                 DeferRef operator[] ( unsigned i ) const { return { decl, info, matches[i] }; }
    92         };
    93 
    94         /// List of deferred satisfaction items
    95         using DeferList = std::vector< DeferItem >;
    96 
    97         /// Set of assertion satisfactions, grouped by resolution ID
    98         using InferCache = std::unordered_map< ast::UniqueId, ast::InferredParams >;
    99 
    100         /// Lexicographically-ordered vector of costs.
    101         /// Lexicographic order comes from default operator< on std::vector.
    102         using CostVec = std::vector< Cost >;
    103 
    104         /// Flag for state iteration
    105         enum IterateFlag { IterateState };
    106 
    107         /// Intermediate state for satisfying a set of assertions
    108         struct SatState {
    109                 CandidateRef cand;          ///< Candidate assertion is rooted on
    110                 ast::AssertionList need;    ///< Assertions to find
    111                 ast::AssertionSet newNeed;  ///< Recursive assertions from current satisfied assertions
    112                 DeferList deferred;         ///< Deferred matches
    113                 InferCache inferred;        ///< Cache of already-inferred assertions
    114                 CostVec costs;              ///< Disambiguating costs of recursive assertion satisfaction
    115                 ast::SymbolTable symtab;    ///< Name lookup (depends on previous assertions)
    116 
    117                 /// Initial satisfaction state for a candidate
    118                 SatState( CandidateRef & c, const ast::SymbolTable & syms )
    119                 : cand( c ), need(), newNeed(), deferred(), inferred(), costs{ Cost::zero },
    120                   symtab( syms ) { need.swap( c->need ); }
    121                
    122                 /// Update satisfaction state for next step after previous state
    123                 SatState( SatState && o, IterateFlag )
    124                 : cand( std::move( o.cand ) ), need( o.newNeed.begin(), o.newNeed.end() ), newNeed(),
    125                   deferred(), inferred( std::move( o.inferred ) ), costs( std::move( o.costs ) ),
    126                   symtab( o.symtab ) { costs.emplace_back( Cost::zero ); }
    127                
    128                 /// Field-wise next step constructor
    129                 SatState(
    130                         CandidateRef && c, ast::AssertionSet && nn, InferCache && i, CostVec && cs,
    131                         ast::SymbolTable && syms )
    132                 : cand( std::move( c ) ), need( nn.begin(), nn.end() ), newNeed(), deferred(),
    133                   inferred( std::move( i ) ), costs( std::move( cs ) ), symtab( std::move( syms ) )
    134                   { costs.emplace_back( Cost::zero ); }
    135         };
    136 
    137         /// Adds a captured assertion to the symbol table
    138         void addToSymbolTable( const ast::AssertionSet & have, ast::SymbolTable & symtab ) {
    139                 for ( auto & i : have ) {
    140                         if ( i.second.isUsed ) { symtab.addId( i.first ); }
    141                 }
    142         }
    143 
    144         /// Binds a single assertion, updating satisfaction state
    145         void bindAssertion(
    146                 const ast::DeclWithType * decl, const ast::AssertionSetValue & info, CandidateRef & cand,
    147                 AssnCandidate & match, InferCache & inferred
    148         ) {
    149                 const ast::DeclWithType * candidate = match.cdata.id;
    150                 assertf( candidate->uniqueId,
    151                         "Assertion candidate does not have a unique ID: %s", toString( candidate ).c_str() );
    152                
    153                 ast::Expr * varExpr = match.cdata.combine( cand->expr->location, cand->cvtCost );
    154                 varExpr->result = match.adjType;
    155                 if ( match.resnSlot ) { varExpr->inferred.resnSlots().emplace_back( match.resnSlot ); }
    156 
    157                 // place newly-inferred assertion in proper location in cache
    158                 inferred[ info.resnSlot ][ decl->uniqueId ] = ast::ParamEntry{
    159                         candidate->uniqueId, candidate, match.adjType, decl->get_type(), varExpr };
    160         }
    161 
    162         /// Satisfy a single assertion
    163         bool satisfyAssertion( ast::AssertionList::value_type & assn, SatState & sat ) {
    164                 // skip unused assertions
    165                 if ( ! assn.second.isUsed ) return true;
    166 
    167                 // find candidates that unify with the desired type
    168                 AssnCandidateList matches;
    169                 for ( const ast::SymbolTable::IdData & cdata : sat.symtab.lookupId( assn.first->name ) ) {
    170                         const ast::DeclWithType * candidate = cdata.id;
    171 
    172                         // build independent unification context for candidate
    173                         ast::AssertionSet have, newNeed;
    174                         ast::TypeEnvironment newEnv{ sat.cand->env };
    175                         ast::OpenVarSet newOpen{ sat.cand->open };
    176                         ast::ptr< ast::Type > toType = assn.first->get_type();
    177                         ast::ptr< ast::Type > adjType =
    178                                 renameTyVars( adjustExprType( candidate->get_type(), newEnv, sat.symtab ) );
    179 
    180                         // only keep candidates which unify
    181                         if ( unify( toType, adjType, newEnv, newNeed, have, newOpen, sat.symtab ) ) {
    182                                 // set up binding slot for recursive assertions
    183                                 ast::UniqueId crntResnSlot = 0;
    184                                 if ( ! newNeed.empty() ) {
    185                                         crntResnSlot = ++globalResnSlot;
    186                                         for ( auto & a : newNeed ) { a.second.resnSlot = crntResnSlot; }
    187                                 }
    188 
    189                                 matches.emplace_back(
    190                                         cdata, adjType, std::move( newEnv ), std::move( newNeed ), std::move( have ),
    191                                         std::move( newOpen ), crntResnSlot );
    192                         }
    193                 }
    194 
    195                 // break if no satisfying match
    196                 if ( matches.empty() ) return false;
    197 
    198                 // defer if too many satisfying matches
    199                 if ( matches.size() > 1 ) {
    200                         sat.deferred.emplace_back( assn.first, assn.second, std::move( matches ) );
    201                         return true;
    202                 }
    203 
    204                 // otherwise bind unique match in ongoing scope
    205                 AssnCandidate & match = matches.front();
    206                 addToSymbolTable( match.have, sat.symtab );
    207                 sat.newNeed.insert( match.need.begin(), match.need.end() );
    208                 sat.cand->env = std::move( match.env );
    209                 sat.cand->open = std::move( match.open );
    210 
    211                 bindAssertion( assn.first, assn.second, sat.cand, match, sat.inferred );
    212                 return true;
    213         }
    214 
    215         /// Map of candidate return types to recursive assertion satisfaction costs
    216         using PruneMap = std::unordered_map< std::string, CostVec >;
    217 
    218         /// Gets the pruning key for a candidate (derived from environment-adjusted return type)
    219         std::string pruneKey( const Candidate & cand ) {
    220                 ast::ptr< ast::Type > resType = cand.expr->result;
    221                 cand.env.apply( resType );
    222                 return Mangle::mangle( resType, Mangle::typeMode() );
    223         }
    224 
    225         /// Associates inferred parameters with an expression
    226         struct InferMatcher final {
    227                 InferCache & inferred;
    228 
    229                 InferMatcher( InferCache & inferred ) : inferred( inferred ) {}
    230 
    231                 const ast::Expr * postmutate( const ast::Expr * expr ) {
    232                         // Skip if no slots to find
    233                         if ( expr->inferred.mode != ast::Expr::InferUnion::Slots ) return expr;
    234 
    235                         // find inferred parameters for resolution slots
    236                         ast::InferredParams newInferred;
    237                         for ( UniqueId slot : expr->inferred.resnSlots() ) {
    238                                 // fail if no matching assertions found
    239                                 auto it = inferred.find( slot );
    240                                 if ( it == inferred.end() ) {
    241                                         assert(!"missing assertion");
    242                                 }
    243 
    244                                 // place inferred parameters into new map
    245                                 for ( auto & entry : it->second ) {
    246                                         // recurse on inferParams of resolved expressions
    247                                         entry.second.expr = postmutate( entry.second.expr );
    248                                         auto res = newInferred.emplace( entry );
    249                                         assert( res.second && "all assertions newly placed" );
    250                                 }
    251                         }
    252 
    253                         ast::Expr * ret = mutate( expr );
    254                         ret->inferred.set_inferParams( std::move( newInferred ) );
    255                         return ret;
    256                 }
    257         };
    258 
    259         /// Replace ResnSlots with InferParams and add alternative to output list, if it meets pruning
    260         /// threshold.
    261         void finalizeAssertions(
    262                 CandidateRef & cand, InferCache & inferred, PruneMap & thresholds, CostVec && costs,
    263                 CandidateList & out
    264         ) {
    265                 // prune if cheaper alternative for same key has already been generated
    266                 std::string key = pruneKey( *cand );
    267                 auto it = thresholds.find( key );
    268                 if ( it != thresholds.end() ) {
    269                         if ( it->second < costs ) return;
    270                 } else {
    271                         thresholds.emplace_hint( it, key, std::move( costs ) );
    272                 }
    273 
    274                 // replace resolution slots with inferred parameters, add to output
    275                 ast::Pass< InferMatcher > matcher{ inferred };
    276                 cand->expr = cand->expr->accept( matcher );
    277                 out.emplace_back( cand );
    278         }
    279 
    280         /// Combo iterator that combines candidates into an output list, merging their environments.
    281         /// Rejects an appended candidate if environments cannot be merged. See `Common/FilterCombos.h`
    282         /// for description of "combo iterator".
    283         class CandidateEnvMerger {
    284                 /// Current list of merged candidates
    285                 std::vector< DeferRef > crnt;
    286                 /// Stack of environments to support backtracking
    287                 std::vector< ast::TypeEnvironment > envs;
    288                 /// Stack of open variables to support backtracking
    289                 std::vector< ast::OpenVarSet > opens;
    290                 /// Symbol table to use for merges
    291                 const ast::SymbolTable & symtab;
    292 
    293         public:
    294                 /// The merged environment/open variables and the list of candidates
    295                 struct OutType {
    296                         ast::TypeEnvironment env;
    297                         ast::OpenVarSet open;
    298                         std::vector< DeferRef > assns;
    299                         Cost cost;
    300 
    301                         OutType(
    302                                 const ast::TypeEnvironment & e, const ast::OpenVarSet & o,
    303                                 const std::vector< DeferRef > & as, const ast::SymbolTable & symtab )
    304                         : env( e ), open( o ), assns( as ), cost( Cost::zero ) {
    305                                 // compute combined conversion cost
    306                                 for ( const DeferRef & assn : assns ) {
    307                                         // compute conversion cost from satisfying decl to assertion
    308                                         cost += computeConversionCost(
    309                                                 assn.match.adjType, assn.decl->get_type(), symtab, env );
    310                                        
    311                                         // mark vars+specialization on function-type assertions
    312                                         const ast::FunctionType * func =
    313                                                 GenPoly::getFunctionType( assn.match.cdata.id->get_type() );
    314                                         if ( ! func ) continue;
    315 
    316                                         for ( const ast::DeclWithType * param : func->params ) {
    317                                                 cost.decSpec( specCost( param->get_type() ) );
    318                                         }
    319                                        
    320                                         cost.incVar( func->forall.size() );
    321                                        
    322                                         for ( const ast::TypeDecl * td : func->forall ) {
    323                                                 cost.decSpec( td->assertions.size() );
    324                                         }
    325                                 }
    326                         }
    327 
    328                         bool operator< ( const OutType & o ) const { return cost < o.cost; }
    329                 };
    330 
    331                 CandidateEnvMerger(
    332                         const ast::TypeEnvironment & env, const ast::OpenVarSet & open,
    333                         const ast::SymbolTable & syms )
    334                 : crnt(), envs{ env }, opens{ open }, symtab( syms ) {}
    335 
    336                 bool append( DeferRef i ) {
    337                         ast::TypeEnvironment env = envs.back();
    338                         ast::OpenVarSet open = opens.back();
    339                         mergeOpenVars( open, i.match.open );
    340 
    341                         if ( ! env.combine( i.match.env, open, symtab ) ) return false;
    342 
    343                         crnt.emplace_back( i );
    344                         envs.emplace_back( std::move( env ) );
    345                         opens.emplace_back( std::move( open ) );
    346                         return true;
    347                 }
    348 
    349                 void backtrack() {
    350                         crnt.pop_back();
    351                         envs.pop_back();
    352                         opens.pop_back();
    353                 }
    354 
    355                 OutType finalize() { return { envs.back(), opens.back(), crnt, symtab }; }
    356         };
    357 
    358         /// Limit to depth of recursion of assertion satisfaction
    359         static const int recursionLimit = 4;
    360         /// Maximum number of simultaneously-deferred assertions to attempt concurrent satisfaction of
    361         static const int deferLimit = 10;
    362 } // anonymous namespace
    363 
    36422void satisfyAssertions(
    365         CandidateRef & cand, const ast::SymbolTable & symtab, CandidateList & out,
     23        Candidate & alt, const ast::SymbolTable & symtab, CandidateList & out,
    36624        std::vector<std::string> & errors
    36725) {
    368         // finish early if no assertions to satisfy
    369         if ( cand->need.empty() ) {
    370                 out.emplace_back( cand );
    371                 return;
    372         }
    373 
    374         // build list of possible combinations of satisfying declarations
    375         std::vector< SatState > sats{ SatState{ cand, symtab } };
    376         std::vector< SatState > nextSats{};
    377 
    378         // pruning thresholds by result type of output candidates.
    379         // Candidates *should* be generated in sorted order, so no need to retroactively prune
    380         PruneMap thresholds;
    381 
    382         // satisfy assertions in breadth-first order over the recursion tree of assertion satisfaction.
    383         // Stop recursion at a limited number of levels deep to avoid infinite loops.
    384         for ( unsigned level = 0; level < recursionLimit; ++level ) {
    385                 // for each current mutually-compatible set of assertions
    386                 for ( SatState & sat : sats ) {
    387                         // stop this branch if a better option is already found
    388                         auto it = thresholds.find( pruneKey( *sat.cand ) );
    389                         if ( it != thresholds.end() && it->second < sat.costs ) goto nextSat;
    390 
    391                         // make initial pass at matching assertions
    392                         for ( auto & assn : sat.need ) {
    393                                 // fail early if any assertion is not satisfiable
    394                                 if ( ! satisfyAssertion( assn, sat ) ) {
    395                                         Indenter tabs{ 3 };
    396                                         std::ostringstream ss;
    397                                         ss << tabs << "Unsatisfiable alternative:\n";
    398                                         print( ss, *sat.cand, ++tabs );
    399                                         ss << (tabs-1) << "Could not satisfy assertion:\n";
    400                                         ast::print( ss, assn.first, tabs );
    401 
    402                                         errors.emplace_back( ss.str() );
    403                                         goto nextSat;
    404                                 }
    405                         }
    406 
    407                         if ( sat.deferred.empty() ) {
    408                                 // either add successful match or push back next state
    409                                 if ( sat.newNeed.empty() ) {
    410                                         finalizeAssertions(
    411                                                 sat.cand, sat.inferred, thresholds, std::move( sat.costs ), out );
    412                                 } else {
    413                                         nextSats.emplace_back( std::move( sat ), IterateState );
    414                                 }
    415                         } else if ( sat.deferred.size() > deferLimit ) {
    416                                 // too many deferred assertions to attempt mutual compatibility
    417                                 Indenter tabs{ 3 };
    418                                 std::ostringstream ss;
    419                                 ss << tabs << "Unsatisfiable alternative:\n";
    420                                 print( ss, *sat.cand, ++tabs );
    421                                 ss << (tabs-1) << "Too many non-unique satisfying assignments for assertions:\n";
    422                                 for ( const auto & d : sat.deferred ) {
    423                                         ast::print( ss, d.decl, tabs );
    424                                 }
    425 
    426                                 errors.emplace_back( ss.str() );
    427                                 goto nextSat;
    428                         } else {
    429                                 // combine deferred assertions by mutual compatibility
    430                                 std::vector< CandidateEnvMerger::OutType > compatible = filterCombos(
    431                                         sat.deferred, CandidateEnvMerger{ sat.cand->env, sat.cand->open, sat.symtab } );
    432                                
    433                                 // fail early if no mutually-compatible assertion satisfaction
    434                                 if ( compatible.empty() ) {
    435                                         Indenter tabs{ 3 };
    436                                         std::ostringstream ss;
    437                                         ss << tabs << "Unsatisfiable alternative:\n";
    438                                         print( ss, *sat.cand, ++tabs );
    439                                         ss << (tabs-1) << "No mutually-compatible satisfaction for assertions:\n";
    440                                         for ( const auto& d : sat.deferred ) {
    441                                                 ast::print( ss, d.decl, tabs );
    442                                         }
    443 
    444                                         errors.emplace_back( ss.str() );
    445                                         goto nextSat;
    446                                 }
    447 
    448                                 // sort by cost (for overall pruning order)
    449                                 std::sort( compatible.begin(), compatible.end() );
    450 
    451                                 // process mutually-compatible combinations
    452                                 for ( auto & compat : compatible ) {
    453                                         // set up next satisfaction state
    454                                         CandidateRef nextCand = std::make_shared<Candidate>(
    455                                                 sat.cand->expr, std::move( compat.env ), std::move( compat.open ),
    456                                                 ast::AssertionSet{} /* need moved into satisfaction state */,
    457                                                 sat.cand->cost, sat.cand->cvtCost );
    458 
    459                                         ast::AssertionSet nextNewNeed{ sat.newNeed };
    460                                         InferCache nextInferred{ sat.inferred };
    461                                        
    462                                         CostVec nextCosts{ sat.costs };
    463                                         nextCosts.back() += compat.cost;
    464                                                                
    465                                         ast::SymbolTable nextSymtab{ sat.symtab };
    466 
    467                                         // add compatible assertions to new satisfaction state
    468                                         for ( DeferRef r : compat.assns ) {
    469                                                 AssnCandidate match = r.match;
    470                                                 addToSymbolTable( match.have, nextSymtab );
    471                                                 nextNewNeed.insert( match.need.begin(), match.need.end() );
    472 
    473                                                 bindAssertion( r.decl, r.info, nextCand, match, nextInferred );
    474                                         }
    475 
    476                                         // either add successful match or push back next state
    477                                         if ( nextNewNeed.empty() ) {
    478                                                 finalizeAssertions(
    479                                                         nextCand, nextInferred, thresholds, std::move( nextCosts ), out );
    480                                         } else {
    481                                                 nextSats.emplace_back(
    482                                                         std::move( nextCand ), std::move( nextNewNeed ),
    483                                                         std::move( nextInferred ), std::move( nextCosts ),
    484                                                         std::move( nextSymtab ) );
    485                                         }
    486                                 }
    487                         }
    488                 nextSat:; }
    489 
    490                 // finish or reset for next round
    491                 if ( nextSats.empty() ) return;
    492                 sats.swap( nextSats );
    493                 nextSats.clear();
    494         }
    495        
    496         // exceeded recursion limit if reaches here
    497         if ( out.empty() ) {
    498                 SemanticError( cand->expr->location, "Too many recursive assertions" );
    499         }
     26        #warning unimplemented
     27        (void)alt; (void)symtab; (void)out; (void)errors;
     28        assert(false);
    50029}
    50130
Note: See TracChangeset for help on using the changeset viewer.