Ignore:
Timestamp:
Feb 7, 2022, 12:50:05 PM (2 years ago)
Author:
Thierry Delisle <tdelisle@…>
Branches:
ADT, ast-experimental, enum, forall-pointer-decay, master, pthread-emulation, qualifiedEnum
Children:
250583e
Parents:
b56ad5e
Message:

Change pass visitor to avoid more transient strong references

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/AST/Pass.impl.hpp

    rb56ad5e rf8143a6  
    7979
    8080                template<typename it_t, template <class...> class container_t>
    81                 static inline void take_all( it_t it, container_t<ast::ptr<ast::Stmt>> * decls, bool * mutated = nullptr ) {
    82                         if(empty(decls)) return;
    83 
    84                         std::move(decls->begin(), decls->end(), it);
    85                         decls->clear();
     81                static inline void take_all( it_t it, container_t<ast::ptr<ast::Stmt>> * stmts, bool * mutated = nullptr ) {
     82                        if(empty(stmts)) return;
     83
     84                        std::move(stmts->begin(), stmts->end(), it);
     85                        stmts->clear();
    8686                        if(mutated) *mutated = true;
    8787                }
     
    123123                        return !new_val.empty();
    124124                }
     125        }
     126
     127
     128        template< typename core_t >
     129        template< typename node_t >
     130        template< typename object_t, typename super_t, typename field_t >
     131        void ast::Pass< core_t >::result1< node_t >::apply(object_t * object, field_t super_t::* field) {
     132                object->*field = value;
    125133        }
    126134
     
    131139                                !std::is_base_of<ast::Expr, node_t>::value &&
    132140                                !std::is_base_of<ast::Stmt, node_t>::value
    133                         , decltype( node->accept(*this) )
     141                        , ast::Pass< core_t >::result1<
     142                                typename std::remove_pointer< decltype( node->accept(*this) ) >::type
     143                        >
    134144                >::type
    135145        {
     
    140150                static_assert( !std::is_base_of<ast::Stmt, node_t>::value, "ERROR");
    141151
    142                 return node->accept( *this );
     152                auto nval = node->accept( *this );
     153                ast::Pass< core_t >::result1<
     154                        typename std::remove_pointer< decltype( node->accept(*this) ) >::type
     155                > res;
     156                res.differs = nval != node;
     157                res.value = nval;
     158                return res;
    143159        }
    144160
    145161        template< typename core_t >
    146         const ast::Expr * ast::Pass< core_t >::call_accept( const ast::Expr * expr ) {
     162        ast::Pass< core_t >::result1<ast::Expr> ast::Pass< core_t >::call_accept( const ast::Expr * expr ) {
    147163                __pedantic_pass_assert( __visit_children() );
    148164                __pedantic_pass_assert( expr );
     
    153169                }
    154170
    155                 return expr->accept( *this );
     171                auto nval = expr->accept( *this );
     172                return { nval != expr, nval };
    156173        }
    157174
    158175        template< typename core_t >
    159         const ast::Stmt * ast::Pass< core_t >::call_accept( const ast::Stmt * stmt ) {
     176        ast::Pass< core_t >::result1<ast::Stmt> ast::Pass< core_t >::call_accept( const ast::Stmt * stmt ) {
    160177                __pedantic_pass_assert( __visit_children() );
    161178                __pedantic_pass_assert( stmt );
    162179
    163                 return stmt->accept( *this );
     180                const ast::Stmt * nval = stmt->accept( *this );
     181                return { nval != stmt, nval };
    164182        }
    165183
    166184        template< typename core_t >
    167         const ast::Stmt * ast::Pass< core_t >::call_accept_as_compound( const ast::Stmt * stmt ) {
     185        ast::Pass< core_t >::result1<ast::Stmt> ast::Pass< core_t >::call_accept_as_compound( const ast::Stmt * stmt ) {
    168186                __pedantic_pass_assert( __visit_children() );
    169187                __pedantic_pass_assert( stmt );
     
    190208                // If the pass doesn't want to add anything then we are done
    191209                if( empty(stmts_before) && empty(stmts_after) && empty(decls_before) && empty(decls_after) ) {
    192                         return nstmt;
     210                        return { nstmt != stmt, nstmt };
    193211                }
    194212
     
    212230                __pass::take_all( std::back_inserter( compound->kids ), stmts_after );
    213231
    214                 return compound;
     232                return {true, compound};
    215233        }
    216234
    217235        template< typename core_t >
    218236        template< template <class...> class container_t >
    219         container_t< ptr<Stmt> > ast::Pass< core_t >::call_accept( const container_t< ptr<Stmt> > & statements ) {
     237        template< typename object_t, typename super_t, typename field_t >
     238        void ast::Pass< core_t >::resultNstmt<container_t>::apply(object_t * object, field_t super_t::* field) {
     239                auto & container = object->*field;
     240                __pedantic_pass_assert( container.size() <= values.size() );
     241
     242                auto cit = enumerate(container).begin();
     243
     244                container_t<ptr<Stmt>> nvals;
     245                for(delta & d : values) {
     246                        if( d.is_old ) {
     247                                __pedantic_pass_assert( cit.idx <= d.old_idx );
     248                                std::advance( cit, d.old_idx - cit.idx );
     249                                nvals.push_back( std::move( (*cit).val) );
     250                        } else {
     251                                nvals.push_back( std::move(d.nval) );
     252                        }
     253                }
     254
     255                object->*field = std::move(nvals);
     256        }
     257
     258        template< typename core_t >
     259        template< template <class...> class container_t >
     260        ast::Pass< core_t >::resultNstmt<container_t> ast::Pass< core_t >::call_accept( const container_t< ptr<Stmt> > & statements ) {
    220261                __pedantic_pass_assert( __visit_children() );
    221262                if( statements.empty() ) return {};
     
    244285                pass_visitor_stats.avg->push(pass_visitor_stats.depth);
    245286
    246                 bool mutated = false;
    247                 container_t< ptr<Stmt> > new_kids;
    248                 for( const Stmt * stmt : statements ) {
     287                resultNstmt<container_t> new_kids;
     288                for( auto value : enumerate( statements ) ) {
    249289                        try {
     290                                size_t i = value.idx;
     291                                const Stmt * stmt = value.val;
    250292                                __pedantic_pass_assert( stmt );
    251293                                const ast::Stmt * new_stmt = stmt->accept( *this );
    252294                                assert( new_stmt );
    253                                 if(new_stmt != stmt ) mutated = true;
     295                                if(new_stmt != stmt ) { new_kids.differs = true; }
    254296
    255297                                // Make sure that it is either adding statements or declartions but not both
     
    261303
    262304                                // Take all the statements which should have gone after, N/A for first iteration
    263                                 __pass::take_all( std::back_inserter( new_kids ), decls_before, &mutated );
    264                                 __pass::take_all( std::back_inserter( new_kids ), stmts_before, &mutated );
     305                                new_kids.take_all( decls_before );
     306                                new_kids.take_all( stmts_before );
    265307
    266308                                // Now add the statement if there is one
    267                                 new_kids.emplace_back( new_stmt );
     309                                if(new_stmt != stmt) {
     310                                        new_kids.values.emplace_back( new_stmt, i, false );
     311                                } else {
     312                                        new_kids.values.emplace_back( nullptr, i, true );
     313                                }
    268314
    269315                                // Take all the declarations that go before
    270                                 __pass::take_all( std::back_inserter( new_kids ), decls_after, &mutated );
    271                                 __pass::take_all( std::back_inserter( new_kids ), stmts_after, &mutated );
     316                                new_kids.take_all( decls_after );
     317                                new_kids.take_all( stmts_after );
    272318                        }
    273319                        catch ( SemanticErrorException &e ) {
     
    278324                if ( !errors.isEmpty() ) { throw errors; }
    279325
    280                 return mutated ? new_kids : container_t< ptr<Stmt> >();
     326                return new_kids;
    281327        }
    282328
    283329        template< typename core_t >
    284330        template< template <class...> class container_t, typename node_t >
    285         container_t< ast::ptr<node_t> > ast::Pass< core_t >::call_accept( const container_t< ast::ptr<node_t> > & container ) {
     331        template< typename object_t, typename super_t, typename field_t >
     332        void ast::Pass< core_t >::resultN<container_t, node_t>::apply(object_t * object, field_t super_t::* field) {
     333                auto & container = object->*field;
     334                __pedantic_pass_assert( container.size() == values.size() );
     335
     336                for(size_t i = 0; i < container.size(); i++) {
     337                        // Take all the elements that are different in 'values'
     338                        // and swap them into 'container'
     339                        if( values[i] != nullptr ) std::swap(container[i], values[i]);
     340                }
     341
     342                // Now the original containers should still have the unchanged values
     343                // but also contain the new values
     344        }
     345
     346        template< typename core_t >
     347        template< template <class...> class container_t, typename node_t >
     348        ast::Pass< core_t >::resultN<container_t, node_t> ast::Pass< core_t >::call_accept( const container_t< ast::ptr<node_t> > & container ) {
    286349                __pedantic_pass_assert( __visit_children() );
    287350                if( container.empty() ) return {};
     
    293356
    294357                bool mutated = false;
    295                 container_t< ast::ptr<node_t> > new_kids;
     358                container_t<ptr<node_t>> new_kids;
    296359                for ( const node_t * node : container ) {
    297360                        try {
    298361                                __pedantic_pass_assert( node );
    299362                                const node_t * new_stmt = strict_dynamic_cast< const node_t * >( node->accept( *this ) );
    300                                 if(new_stmt != node ) mutated = true;
    301 
    302                                 new_kids.emplace_back( new_stmt );
     363                                if(new_stmt != node ) {
     364                                        mutated = true;
     365                                        new_kids.emplace_back( new_stmt );
     366                                } else {
     367                                        new_kids.emplace_back( nullptr );
     368                                }
     369
    303370                        }
    304371                        catch( SemanticErrorException &e ) {
     
    306373                        }
    307374                }
     375
     376                __pedantic_pass_assert( new_kids.size() == container.size() );
    308377                pass_visitor_stats.depth--;
    309378                if ( ! errors.isEmpty() ) { throw errors; }
    310379
    311                 return mutated ? new_kids : container_t< ast::ptr<node_t> >();
     380                return ast::Pass< core_t >::resultN<container_t, node_t>{ mutated,  new_kids };
    312381        }
    313382
     
    327396                auto new_val = call_accept( old_val );
    328397
    329                 static_assert( !std::is_same<const ast::Node *, decltype(new_val)>::value || std::is_same<int, decltype(old_val)>::value, "ERROR");
    330 
    331                 if( __pass::differs(old_val, new_val) ) {
     398                static_assert( !std::is_same<const ast::Node *, decltype(new_val)>::value /* || std::is_same<int, decltype(old_val)>::value */, "ERROR");
     399
     400                if( new_val.differs ) {
    332401                        auto new_parent = __pass::mutate<core_t>(parent);
    333                         new_parent->*child = new_val;
     402                        new_val.apply(new_parent, child);
    334403                        parent = new_parent;
    335404                }
     
    353422                static_assert( !std::is_same<const ast::Node *, decltype(new_val)>::value || std::is_same<int, decltype(old_val)>::value, "ERROR");
    354423
    355                 if( __pass::differs(old_val, new_val) ) {
     424                if( new_val.differs ) {
    356425                        auto new_parent = __pass::mutate<core_t>(parent);
    357                         new_parent->*child = new_val;
     426                        new_val.apply( new_parent, child );
    358427                        parent = new_parent;
    359428                }
     
    9411010                        const Expr * func = clause.target.func ? clause.target.func->accept(*this) : nullptr;
    9421011                        if(func != clause.target.func) mutated = true;
     1012                        else func = nullptr;
    9431013
    9441014                        std::vector<ptr<Expr>> new_args;
     
    9461016                        for( const auto & arg : clause.target.args ) {
    9471017                                auto a = arg->accept(*this);
    948                                 new_args.push_back( a );
    949                                 if( a != arg ) mutated = true;
     1018                                if( a != arg ) {
     1019                                        mutated = true;
     1020                                        new_args.push_back( a );
     1021                                } else
     1022                                        new_args.push_back( nullptr );
    9501023                        }
    9511024
    9521025                        const Stmt * stmt = clause.stmt ? clause.stmt->accept(*this) : nullptr;
    9531026                        if(stmt != clause.stmt) mutated = true;
     1027                        else stmt = nullptr;
    9541028
    9551029                        const Expr * cond = clause.cond ? clause.cond->accept(*this) : nullptr;
    9561030                        if(cond != clause.cond) mutated = true;
     1031                        else cond = nullptr;
    9571032
    9581033                        new_clauses.push_back( WaitForStmt::Clause{ {func, std::move(new_args) }, stmt, cond } );
     
    9611036                if(mutated) {
    9621037                        auto n = __pass::mutate<core_t>(node);
    963                         n->clauses = std::move( new_clauses );
     1038                        for(size_t i = 0; i < new_clauses.size(); i++) {
     1039                                if(new_clauses.at(i).target.func != nullptr) std::swap(n->clauses.at(i).target.func, new_clauses.at(i).target.func);
     1040
     1041                                for(size_t j = 0; j < new_clauses.at(i).target.args.size(); j++) {
     1042                                        if(new_clauses.at(i).target.args.at(j) != nullptr) std::swap(n->clauses.at(i).target.args.at(j), new_clauses.at(i).target.args.at(j));
     1043                                }
     1044
     1045                                if(new_clauses.at(i).stmt != nullptr) std::swap(n->clauses.at(i).stmt, new_clauses.at(i).stmt);
     1046                                if(new_clauses.at(i).cond != nullptr) std::swap(n->clauses.at(i).cond, new_clauses.at(i).cond);
     1047                        }
    9641048                        node = n;
    9651049                }
     
    9691053                if(node->field) { \
    9701054                        auto nval = call_accept( node->field ); \
    971                         if(nval != node->field ) { \
     1055                        if(nval.differs ) { \
    9721056                                auto nparent = __pass::mutate<core_t>(node); \
    973                                 nparent->field = nval; \
     1057                                nparent->field = nval.value; \
    9741058                                node = nparent; \
    9751059                        } \
Note: See TracChangeset for help on using the changeset viewer.