Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/InitTweak/InitTweak.cc

    r033ff37 r57acae0  
    1 //
    2 // Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
    3 //
    4 // The contents of this file are covered under the licence agreement in the
    5 // file "LICENCE" distributed with Cforall.
    6 //
    7 // InitTweak.cc --
    8 //
    9 // Author           : Rob Schluntz
    10 // Created On       : Fri May 13 11:26:36 2016
    11 // Last Modified By : Peter A. Buhr
    12 // Last Modified On : Thu Jul 25 22:21:48 2019
    13 // Update Count     : 7
    14 //
    15 
    161#include <algorithm>               // for find, all_of
    172#include <cassert>                 // for assertf, assert, strict_dynamic_cast
     
    194#include <iterator>                // for back_insert_iterator, back_inserter
    205#include <memory>                  // for __shared_ptr
    21 #include <vector>
    22 
    23 #include "AST/Expr.hpp"
    24 #include "AST/Init.hpp"
    25 #include "AST/Node.hpp"
    26 #include "AST/Pass.hpp"
    27 #include "AST/Stmt.hpp"
    28 #include "AST/Type.hpp"
     6
    297#include "Common/PassVisitor.h"
    308#include "Common/SemanticError.h"  // for SemanticError
     
    4826#include "Tuples/Tuples.h"         // for Tuples::isTtype
    4927
     28class UntypedValofExpr;
     29
    5030namespace InitTweak {
    5131        namespace {
     
    8767                };
    8868
    89                 struct InitFlattener_old : public WithShortCircuiting {
     69                struct InitFlattener : public WithShortCircuiting {
    9070                        void previsit( SingleInit * singleInit ) {
    9171                                visit_children = false;
     
    9575                };
    9676
    97                 struct InitFlattener_new : public ast::WithShortCircuiting {
    98                         std::vector< ast::ptr< ast::Expr > > argList;
    99 
    100                         void previsit( const ast::SingleInit * singleInit ) {
    101                                 visit_children = false;
    102                                 argList.emplace_back( singleInit->value );
    103                         }
    104                 };
    105 
    106         } // anonymous namespace
     77        }
    10778
    10879        std::list< Expression * > makeInitList( Initializer * init ) {
    109                 PassVisitor<InitFlattener_old> flattener;
     80                PassVisitor<InitFlattener> flattener;
    11081                maybeAccept( init, flattener );
    11182                return flattener.pass.argList;
     
    12495        }
    12596
    126 std::vector< ast::ptr< ast::Expr > > makeInitList( const ast::Init * init ) {
    127         ast::Pass< InitFlattener_new > flattener;
    128         maybe_accept( init, flattener );
    129         return std::move( flattener.pass.argList );
    130 }
    131 
    132         class InitExpander_old::ExpanderImpl {
     97        class InitExpander::ExpanderImpl {
    13398        public:
    13499                virtual ~ExpanderImpl() = default;
     
    137102        };
    138103
    139         class InitImpl_old : public InitExpander_old::ExpanderImpl {
     104        class InitImpl : public InitExpander::ExpanderImpl {
    140105        public:
    141                 InitImpl_old( Initializer * init ) : init( init ) {}
    142                 virtual ~InitImpl_old() = default;
     106                InitImpl( Initializer * init ) : init( init ) {}
     107                virtual ~InitImpl() = default;
    143108
    144109                virtual std::list< Expression * > next( __attribute((unused)) std::list< Expression * > & indices ) {
     
    154119        };
    155120
    156         class ExprImpl_old : public InitExpander_old::ExpanderImpl {
     121        class ExprImpl : public InitExpander::ExpanderImpl {
    157122        public:
    158                 ExprImpl_old( Expression * expr ) : arg( expr ) {}
    159                 virtual ~ExprImpl_old() { delete arg; }
     123                ExprImpl( Expression * expr ) : arg( expr ) {}
     124                virtual ~ExprImpl() { delete arg; }
    160125
    161126                virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
     
    181146        };
    182147
    183         InitExpander_old::InitExpander_old( Initializer * init ) : expander( new InitImpl_old( init ) ) {}
    184 
    185         InitExpander_old::InitExpander_old( Expression * expr ) : expander( new ExprImpl_old( expr ) ) {}
    186 
    187         std::list< Expression * > InitExpander_old::operator*() {
     148        InitExpander::InitExpander( Initializer * init ) : expander( new InitImpl( init ) ) {}
     149
     150        InitExpander::InitExpander( Expression * expr ) : expander( new ExprImpl( expr ) ) {}
     151
     152        std::list< Expression * > InitExpander::operator*() {
    188153                return cur;
    189154        }
    190155
    191         InitExpander_old & InitExpander_old::operator++() {
     156        InitExpander & InitExpander::operator++() {
    192157                cur = expander->next( indices );
    193158                return *this;
     
    195160
    196161        // use array indices list to build switch statement
    197         void InitExpander_old::addArrayIndex( Expression * index, Expression * dimension ) {
     162        void InitExpander::addArrayIndex( Expression * index, Expression * dimension ) {
    198163                indices.push_back( index );
    199164                indices.push_back( dimension );
    200165        }
    201166
    202         void InitExpander_old::clearArrayIndices() {
     167        void InitExpander::clearArrayIndices() {
    203168                deleteAll( indices );
    204169                indices.clear();
    205170        }
    206171
    207         bool InitExpander_old::addReference() {
     172        bool InitExpander::addReference() {
    208173                bool added = false;
    209174                for ( Expression *& expr : cur ) {
     
    236201
    237202                template< typename OutIterator >
    238                 void build( UntypedExpr * callExpr, InitExpander_old::IndexList::iterator idx, InitExpander_old::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
     203                void build( UntypedExpr * callExpr, InitExpander::IndexList::iterator idx, InitExpander::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
    239204                        if ( idx == idxEnd ) return;
    240205                        Expression * index = *idx++;
     
    293258        // remaining elements.
    294259        // To accomplish this, generate switch statement, consuming all of expander's elements
    295         Statement * InitImpl_old::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
     260        Statement * InitImpl::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
    296261                if ( ! init ) return nullptr;
    297262                CompoundStmt * block = new CompoundStmt();
     
    306271        }
    307272
    308         Statement * ExprImpl_old::buildListInit( UntypedExpr *, std::list< Expression * > & ) {
     273        Statement * ExprImpl::buildListInit( UntypedExpr *, std::list< Expression * > & ) {
    309274                return nullptr;
    310275        }
    311276
    312         Statement * InitExpander_old::buildListInit( UntypedExpr * dst ) {
     277        Statement * InitExpander::buildListInit( UntypedExpr * dst ) {
    313278                return expander->buildListInit( dst, indices );
    314279        }
    315 
    316 class InitExpander_new::ExpanderImpl {
    317 public:
    318         virtual ~ExpanderImpl() = default;
    319         virtual std::vector< ast::ptr< ast::Expr > > next( IndexList & indices ) = 0;
    320         virtual ast::ptr< ast::Stmt > buildListInit(
    321                 ast::UntypedExpr * callExpr, IndexList & indices ) = 0;
    322 };
    323 
    324 namespace {
    325         template< typename Out >
    326         void buildCallExpr(
    327                 ast::UntypedExpr * callExpr, const ast::Expr * index, const ast::Expr * dimension,
    328                 const ast::Init * init, Out & out
    329         ) {
    330                 const CodeLocation & loc = init->location;
    331 
    332                 auto cond = new ast::UntypedExpr{
    333                         loc, new ast::NameExpr{ loc, "?<?" }, { index, dimension } };
    334 
    335                 std::vector< ast::ptr< ast::Expr > > args = makeInitList( init );
    336                 splice( callExpr->args, args );
    337 
    338                 out.emplace_back( new ast::IfStmt{ loc, cond, new ast::ExprStmt{ loc, callExpr } } );
    339 
    340                 out.emplace_back( new ast::ExprStmt{
    341                         loc, new ast::UntypedExpr{ loc, new ast::NameExpr{ loc, "++?" }, { index } } } );
    342         }
    343 
    344         template< typename Out >
    345         void build(
    346                 ast::UntypedExpr * callExpr, const InitExpander_new::IndexList & indices,
    347                 const ast::Init * init, Out & out
    348         ) {
    349                 if ( indices.empty() ) return;
    350 
    351                 unsigned idx = 0;
    352 
    353                 const ast::Expr * index = indices[idx++];
    354                 assert( idx != indices.size() );
    355                 const ast::Expr * dimension = indices[idx++];
    356 
    357                 if ( idx == indices.size() ) {
    358                         if ( auto listInit = dynamic_cast< const ast::ListInit * >( init ) ) {
    359                                 for ( const ast::Init * init : *listInit ) {
    360                                         buildCallExpr( callExpr, index, dimension, init, out );
    361                                 }
    362                         } else {
    363                                 buildCallExpr( callExpr, index, dimension, init, out );
    364                         }
    365                 } else {
    366                         const CodeLocation & loc = init->location;
    367 
    368                         unsigned long cond = 0;
    369                         auto listInit = dynamic_cast< const ast::ListInit * >( init );
    370                         if ( ! listInit ) { SemanticError( loc, "unbalanced list initializers" ); }
    371 
    372                         static UniqueName targetLabel( "L__autogen__" );
    373                         ast::Label switchLabel{
    374                                 loc, targetLabel.newName(), { new ast::Attribute{ "unused" } } };
    375 
    376                         std::vector< ast::ptr< ast::Stmt > > branches;
    377                         for ( const ast::Init * init : *listInit ) {
    378                                 auto condition = ast::ConstantExpr::from_ulong( loc, cond );
    379                                 ++cond;
    380 
    381                                 std::vector< ast::ptr< ast::Stmt > > stmts;
    382                                 build( callExpr, indices, init, stmts );
    383                                 stmts.emplace_back(
    384                                         new ast::BranchStmt{ loc, ast::BranchStmt::Break, switchLabel } );
    385                                 branches.emplace_back( new ast::CaseStmt{ loc, condition, std::move( stmts ) } );
    386                         }
    387                         out.emplace_back( new ast::SwitchStmt{ loc, index, std::move( branches ) } );
    388                         out.emplace_back( new ast::NullStmt{ loc, { switchLabel } } );
    389                 }
    390         }
    391 
    392         class InitImpl_new final : public InitExpander_new::ExpanderImpl {
    393                 ast::ptr< ast::Init > init;
    394         public:
    395                 InitImpl_new( const ast::Init * i ) : init( i ) {}
    396 
    397                 std::vector< ast::ptr< ast::Expr > > next( InitExpander_new::IndexList & ) override {
    398                         return makeInitList( init );
    399                 }
    400 
    401                 ast::ptr< ast::Stmt > buildListInit(
    402                         ast::UntypedExpr * callExpr, InitExpander_new::IndexList & indices
    403                 ) override {
    404                         // If array came with an initializer list, initialize each element. We may have more
    405                         // initializers than elements of the array; need to check at each index that we have
    406                         // not exceeded size. We may have fewer initializers than elements in the array; need
    407                         // to default-construct remaining elements. To accomplish this, generate switch
    408                         // statement consuming all of expander's elements
    409 
    410                         if ( ! init ) return {};
    411 
    412                         std::list< ast::ptr< ast::Stmt > > stmts;
    413                         build( callExpr, indices, init, stmts );
    414                         if ( stmts.empty() ) {
    415                                 return {};
    416                         } else {
    417                                 auto block = new ast::CompoundStmt{ init->location, std::move( stmts ) };
    418                                 init = nullptr;  // consumed in creating the list init
    419                                 return block;
    420                         }
    421                 }
    422         };
    423 
    424         class ExprImpl_new final : public InitExpander_new::ExpanderImpl {
    425                 ast::ptr< ast::Expr > arg;
    426         public:
    427                 ExprImpl_new( const ast::Expr * a ) : arg( a ) {}
    428 
    429                 std::vector< ast::ptr< ast::Expr > > next(
    430                         InitExpander_new::IndexList & indices
    431                 ) override {
    432                         if ( ! arg ) return {};
    433 
    434                         const CodeLocation & loc = arg->location;
    435                         const ast::Expr * expr = arg;
    436                         for ( auto it = indices.rbegin(); it != indices.rend(); ++it ) {
    437                                 // go through indices and layer on subscript exprs ?[?]
    438                                 ++it;
    439                                 expr = new ast::UntypedExpr{
    440                                         loc, new ast::NameExpr{ loc, "?[?]" }, { expr, *it } };
    441                         }
    442                         return { expr };
    443                 }
    444 
    445                 ast::ptr< ast::Stmt > buildListInit(
    446                         ast::UntypedExpr *, InitExpander_new::IndexList &
    447                 ) override {
    448                         return {};
    449                 }
    450         };
    451 } // anonymous namespace
    452 
    453 InitExpander_new::InitExpander_new( const ast::Init * init )
    454 : expander( new InitImpl_new{ init } ), crnt(), indices() {}
    455 
    456 InitExpander_new::InitExpander_new( const ast::Expr * expr )
    457 : expander( new ExprImpl_new{ expr } ), crnt(), indices() {}
    458 
    459 std::vector< ast::ptr< ast::Expr > > InitExpander_new::operator* () { return crnt; }
    460 
    461 InitExpander_new & InitExpander_new::operator++ () {
    462         crnt = expander->next( indices );
    463         return *this;
    464 }
    465 
    466 /// builds statement which has the same semantics as a C-style list initializer (for array
    467 /// initializers) using callExpr as the base expression to perform initialization
    468 ast::ptr< ast::Stmt > InitExpander_new::buildListInit( ast::UntypedExpr * callExpr ) {
    469         return expander->buildListInit( callExpr, indices );
    470 }
    471 
    472 void InitExpander_new::addArrayIndex( const ast::Expr * index, const ast::Expr * dimension ) {
    473         indices.emplace_back( index );
    474         indices.emplace_back( dimension );
    475 }
    476 
    477 void InitExpander_new::clearArrayIndices() { indices.clear(); }
    478 
    479 bool InitExpander_new::addReference() {
    480         for ( ast::ptr< ast::Expr > & expr : crnt ) {
    481                 expr = new ast::AddressExpr{ expr };
    482         }
    483         return ! crnt.empty();
    484 }
    485280
    486281        Type * getTypeofThis( FunctionType * ftype ) {
     
    511306        }
    512307
    513         struct CallFinder_old {
    514                 CallFinder_old( const std::list< std::string > & names ) : names( names ) {}
     308        struct CallFinder {
     309                CallFinder( const std::list< std::string > & names ) : names( names ) {}
    515310
    516311                void postvisit( ApplicationExpr * appExpr ) {
     
    535330        };
    536331
    537         struct CallFinder_new final {
    538                 std::vector< ast::ptr< ast::Expr > > matches;
    539                 const std::vector< std::string > names;
    540 
    541                 CallFinder_new( std::vector< std::string > && ns ) : matches(), names( std::move(ns) ) {}
    542 
    543                 void handleCallExpr( const ast::Expr * expr ) {
    544                         std::string fname = getFunctionName( expr );
    545                         if ( std::find( names.begin(), names.end(), fname ) != names.end() ) {
    546                                 matches.emplace_back( expr );
    547                         }
    548                 }
    549 
    550                 void postvisit( const ast::ApplicationExpr * expr ) { handleCallExpr( expr ); }
    551                 void postvisit( const ast::UntypedExpr *     expr ) { handleCallExpr( expr ); }
    552         };
    553 
    554332        void collectCtorDtorCalls( Statement * stmt, std::list< Expression * > & matches ) {
    555                 static PassVisitor<CallFinder_old> finder( std::list< std::string >{ "?{}", "^?{}" } );
     333                static PassVisitor<CallFinder> finder( std::list< std::string >{ "?{}", "^?{}" } );
    556334                finder.pass.matches = &matches;
    557335                maybeAccept( stmt, finder );
    558         }
    559 
    560         std::vector< ast::ptr< ast::Expr > > collectCtorDtorCalls( const ast::Stmt * stmt ) {
    561                 ast::Pass< CallFinder_new > finder{ std::vector< std::string >{ "?{}", "^?{}" } };
    562                 maybe_accept( stmt, finder );
    563                 return std::move( finder.pass.matches );
    564336        }
    565337
     
    567339                std::list< Expression * > matches;
    568340                collectCtorDtorCalls( stmt, matches );
    569                 assertf( matches.size() <= 1, "%zd constructor/destructors found in %s", matches.size(), toString( stmt ).c_str() );
     341                assert( matches.size() <= 1 );
    570342                return matches.size() == 1 ? matches.front() : nullptr;
    571343        }
     
    573345        namespace {
    574346                DeclarationWithType * getCalledFunction( Expression * expr );
    575                 const ast::DeclWithType * getCalledFunction( const ast::Expr * expr );
    576347
    577348                template<typename CallExpr>
     
    583354                        return getCalledFunction( expr->get_args().front() );
    584355                }
    585 
    586                 template<typename CallExpr>
    587                 const ast::DeclWithType * handleDerefCalledFunction( const CallExpr * expr ) {
    588                         // (*f)(x) => should get "f"
    589                         std::string name = getFunctionName( expr );
    590                         assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
    591                         assertf( ! expr->args.empty(), "Cannot get called function from dereference with no arguments" );
    592                         return getCalledFunction( expr->args.front() );
    593                 }
    594 
    595356
    596357                DeclarationWithType * getCalledFunction( Expression * expr ) {
     
    613374                        return nullptr;
    614375                }
    615 
    616                 const ast::DeclWithType * getCalledFunction( const ast::Expr * expr ) {
    617                         assert( expr );
    618                         if ( const ast::VariableExpr * varExpr = dynamic_cast< const ast::VariableExpr * >( expr ) ) {
    619                                 return varExpr->var;
    620                         } else if ( const ast::MemberExpr * memberExpr = dynamic_cast< const ast::MemberExpr * >( expr ) ) {
    621                                 return memberExpr->member;
    622                         } else if ( const ast::CastExpr * castExpr = dynamic_cast< const ast::CastExpr * >( expr ) ) {
    623                                 return getCalledFunction( castExpr->arg );
    624                         } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * >( expr ) ) {
    625                                 return handleDerefCalledFunction( untypedExpr );
    626                         } else if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * > ( expr ) ) {
    627                                 return handleDerefCalledFunction( appExpr );
    628                         } else if ( const ast::AddressExpr * addrExpr = dynamic_cast< const ast::AddressExpr * >( expr ) ) {
    629                                 return getCalledFunction( addrExpr->arg );
    630                         } else if ( const ast::CommaExpr * commaExpr = dynamic_cast< const ast::CommaExpr * >( expr ) ) {
    631                                 return getCalledFunction( commaExpr->arg2 );
    632                         }
    633                         return nullptr;
    634                 }
    635 
    636                 DeclarationWithType * getFunctionCore( const Expression * expr ) {
    637                         if ( const auto * appExpr = dynamic_cast< const ApplicationExpr * >( expr ) ) {
    638                                 return getCalledFunction( appExpr->function );
    639                         } else if ( const auto * untyped = dynamic_cast< const UntypedExpr * >( expr ) ) {
    640                                 return getCalledFunction( untyped->function );
    641                         }
    642                         assertf( false, "getFunction with unknown expression: %s", toString( expr ).c_str() );
    643                 }
    644376        }
    645377
    646378        DeclarationWithType * getFunction( Expression * expr ) {
    647                 return getFunctionCore( expr );
    648         }
    649 
    650         const DeclarationWithType * getFunction( const Expression * expr ) {
    651                 return getFunctionCore( expr );
    652         }
    653 
    654         const ast::DeclWithType * getFunction( const ast::Expr * expr ) {
    655                 if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr ) ) {
    656                         return getCalledFunction( appExpr->func );
    657                 } else if ( const ast::UntypedExpr * untyped = dynamic_cast< const ast::UntypedExpr * > ( expr ) ) {
    658                         return getCalledFunction( untyped->func );
     379                if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr ) ) {
     380                        return getCalledFunction( appExpr->get_function() );
     381                } else if ( UntypedExpr * untyped = dynamic_cast< UntypedExpr * > ( expr ) ) {
     382                        return getCalledFunction( untyped->get_function() );
    659383                }
    660384                assertf( false, "getFunction received unknown expression: %s", toString( expr ).c_str() );
     
    671395        }
    672396
    673         const ast::ApplicationExpr * isIntrinsicCallExpr( const ast::Expr * expr ) {
    674                 auto appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr );
    675                 if ( ! appExpr ) return nullptr;
    676 
    677                 const ast::DeclWithType * func = getCalledFunction( appExpr->func );
    678                 assertf( func,
    679                         "getCalledFunction returned nullptr: %s", toString( appExpr->func ).c_str() );
    680 
    681                 // check for Intrinsic only -- don't want to remove all overridable ctor/dtor because
    682                 // autogenerated ctor/dtor will call all member dtors, and some members may have a
    683                 // user-defined dtor
    684                 return func->linkage == ast::Linkage::Intrinsic ? appExpr : nullptr;
    685         }
    686 
    687397        namespace {
    688398                template <typename Predicate>
     
    693403                        return std::all_of( callExprs.begin(), callExprs.end(), pred);
    694404                }
    695 
    696                 template <typename Predicate>
    697                 bool allofCtorDtor( const ast::Stmt * stmt, const Predicate & pred ) {
    698                         std::vector< ast::ptr< ast::Expr > > callExprs = collectCtorDtorCalls( stmt );
    699                         return std::all_of( callExprs.begin(), callExprs.end(), pred );
    700                 }
    701405        }
    702406
     
    704408                return allofCtorDtor( stmt, []( Expression * callExpr ){
    705409                        if ( ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
    706                                 FunctionType *funcType = GenPoly::getFunctionType( appExpr->function->result );
     410                                FunctionType *funcType = GenPoly::getFunctionType( appExpr->get_function()->get_result() );
    707411                                assert( funcType );
    708412                                return funcType->get_parameters().size() == 1;
    709                         }
    710                         return false;
    711                 });
    712         }
    713 
    714         bool isIntrinsicSingleArgCallStmt( const ast::Stmt * stmt ) {
    715                 return allofCtorDtor( stmt, []( const ast::Expr * callExpr ){
    716                         if ( const ast::ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
    717                                 const ast::FunctionType * funcType =
    718                                         GenPoly::getFunctionType( appExpr->func->result );
    719                                 assert( funcType );
    720                                 return funcType->params.size() == 1;
    721413                        }
    722414                        return false;
     
    737429                                if ( pos == 0 ) return arg;
    738430                                pos--;
    739                         }
    740                         assert( false );
    741                 }
    742 
    743                 template<typename CallExpr>
    744                 const ast::Expr * callArg( const CallExpr * call, unsigned int pos ) {
    745                         if( pos >= call->args.size() ) {
    746                                 assertf( false, "getCallArg for argument that doesn't exist: (%u); %s.",
    747                                         pos, toString( call ).c_str() );
    748                         }
    749                         for ( const ast::Expr * arg : call->args ) {
    750                                 if ( pos == 0 ) return arg;
    751                                 --pos;
    752431                        }
    753432                        assert( false );
     
    774453        }
    775454
    776         const ast::Expr * getCallArg( const ast::Expr * call, unsigned pos ) {
    777                 if ( auto app = dynamic_cast< const ast::ApplicationExpr * >( call ) ) {
    778                         return callArg( app, pos );
    779                 } else if ( auto untyped = dynamic_cast< const ast::UntypedExpr * >( call ) ) {
    780                         return callArg( untyped, pos );
    781                 } else if ( auto tupleAssn = dynamic_cast< const ast::TupleAssignExpr * >( call ) ) {
    782                         const std::list<ast::ptr<ast::Stmt>>& stmts = tupleAssn->stmtExpr->stmts->kids;
    783                         assertf( ! stmts.empty(), "TupleAssignExpr missing statements." );
    784                         auto stmt  = strict_dynamic_cast< const ast::ExprStmt * >( stmts.back().get() );
    785                         auto tuple = strict_dynamic_cast< const ast::TupleExpr * >( stmt->expr.get() );
    786                         assertf( ! tuple->exprs.empty(), "TupleAssignExpr has empty tuple expr.");
    787                         return getCallArg( tuple->exprs.front(), pos );
    788                 } else if ( auto ctor = dynamic_cast< const ast::ImplicitCopyCtorExpr * >( call ) ) {
    789                         return getCallArg( ctor->callExpr, pos );
    790                 } else {
    791                         assertf( false, "Unexpected expression type passed to getCallArg: %s",
    792                                 toString( call ).c_str() );
    793                 }
    794         }
    795 
    796455        namespace {
    797456                std::string funcName( Expression * func );
    798                 std::string funcName( const ast::Expr * func );
    799457
    800458                template<typename CallExpr>
     
    805463                        assertf( ! expr->get_args().empty(), "Cannot get function name from dereference with no arguments" );
    806464                        return funcName( expr->get_args().front() );
    807                 }
    808 
    809                 template<typename CallExpr>
    810                 std::string handleDerefName( const CallExpr * expr ) {
    811                         // (*f)(x) => should get name "f"
    812                         std::string name = getFunctionName( expr );
    813                         assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
    814                         assertf( ! expr->args.empty(), "Cannot get function name from dereference with no arguments" );
    815                         return funcName( expr->args.front() );
    816465                }
    817466
     
    837486                        }
    838487                }
    839 
    840                 std::string funcName( const ast::Expr * func ) {
    841                         if ( const ast::NameExpr * nameExpr = dynamic_cast< const ast::NameExpr * >( func ) ) {
    842                                 return nameExpr->name;
    843                         } else if ( const ast::VariableExpr * varExpr = dynamic_cast< const ast::VariableExpr * >( func ) ) {
    844                                 return varExpr->var->name;
    845                         }       else if ( const ast::CastExpr * castExpr = dynamic_cast< const ast::CastExpr * >( func ) ) {
    846                                 return funcName( castExpr->arg );
    847                         } else if ( const ast::MemberExpr * memberExpr = dynamic_cast< const ast::MemberExpr * >( func ) ) {
    848                                 return memberExpr->member->name;
    849                         } else if ( const ast::UntypedMemberExpr * memberExpr = dynamic_cast< const ast::UntypedMemberExpr * > ( func ) ) {
    850                                 return funcName( memberExpr->member );
    851                         } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * >( func ) ) {
    852                                 return handleDerefName( untypedExpr );
    853                         } else if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( func ) ) {
    854                                 return handleDerefName( appExpr );
    855                         } else if ( const ast::ConstructorExpr * ctorExpr = dynamic_cast< const ast::ConstructorExpr * >( func ) ) {
    856                                 return funcName( getCallArg( ctorExpr->callExpr, 0 ) );
    857                         } else {
    858                                 assertf( false, "Unexpected expression type being called as a function in call expression: %s", toString( func ).c_str() );
    859                         }
    860                 }
    861488        }
    862489
     
    875502        }
    876503
    877         std::string getFunctionName( const ast::Expr * expr ) {
    878                 // there's some unforunate overlap here with getCalledFunction. Ideally this would be able to use getCalledFunction and
    879                 // return the name of the DeclarationWithType, but this needs to work for NameExpr and UntypedMemberExpr, where getCalledFunction
    880                 // can't possibly do anything reasonable.
    881                 if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr ) ) {
    882                         return funcName( appExpr->func );
    883                 } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * > ( expr ) ) {
    884                         return funcName( untypedExpr->func );
    885                 } else {
    886                         std::cerr << expr << std::endl;
    887                         assertf( false, "Unexpected expression type passed to getFunctionName" );
    888                 }
    889         }
    890 
    891504        Type * getPointerBase( Type * type ) {
    892505                if ( PointerType * ptrType = dynamic_cast< PointerType * >( type ) ) {
     
    900513                }
    901514        }
    902         const ast::Type* getPointerBase( const ast::Type* t ) {
    903                 if ( const auto * p = dynamic_cast< const ast::PointerType * >( t ) ) {
    904                         return p->base;
    905                 } else if ( const auto * a = dynamic_cast< const ast::ArrayType * >( t ) ) {
    906                         return a->base;
    907                 } else if ( const auto * r = dynamic_cast< const ast::ReferenceType * >( t ) ) {
    908                         return r->base;
    909                 } else return nullptr;
    910         }
    911515
    912516        Type * isPointerType( Type * type ) {
     
    957561                void previsit( OffsetofExpr * ) {}
    958562                void previsit( OffsetPackExpr * ) {}
     563                void previsit( AttrExpr * ) {}
    959564                void previsit( CommaExpr * ) {}
    960565                void previsit( LogicalExpr * ) {}
     
    1004609        bool isCtorDtorAssign( const std::string & str ) { return isCtorDtor( str ) || isAssignment( str ); }
    1005610
    1006         const FunctionDecl * isCopyFunction( const Declaration * decl, const std::string & fname ) {
    1007                 const FunctionDecl * function = dynamic_cast< const FunctionDecl * >( decl );
     611        FunctionDecl * isCopyFunction( Declaration * decl, const std::string & fname ) {
     612                FunctionDecl * function = dynamic_cast< FunctionDecl * >( decl );
    1008613                if ( ! function ) return nullptr;
    1009614                if ( function->name != fname ) return nullptr;
     
    1022627        }
    1023628
    1024         bool isCopyFunction( const ast::FunctionDecl * decl ) {
    1025                 const ast::FunctionType * ftype = decl->type;
    1026                 if ( ftype->params.size() != 2 ) return false;
    1027 
    1028                 const ast::Type * t1 = getPointerBase( ftype->params.front()->get_type() );
    1029                 if ( ! t1 ) return false;
    1030                 const ast::Type * t2 = ftype->params.back()->get_type();
    1031 
    1032                 return ResolvExpr::typesCompatibleIgnoreQualifiers( t1, t2, ast::SymbolTable{} );
    1033         }
    1034 
    1035         const FunctionDecl * isAssignment( const Declaration * decl ) {
     629        FunctionDecl * isAssignment( Declaration * decl ) {
    1036630                return isCopyFunction( decl, "?=?" );
    1037631        }
    1038         const FunctionDecl * isDestructor( const Declaration * decl ) {
    1039                 if ( isDestructor( decl->name ) ) {
    1040                         return dynamic_cast< const FunctionDecl * >( decl );
     632        FunctionDecl * isDestructor( Declaration * decl ) {
     633                if ( isDestructor( decl->get_name() ) ) {
     634                        return dynamic_cast< FunctionDecl * >( decl );
    1041635                }
    1042636                return nullptr;
    1043637        }
    1044         const FunctionDecl * isDefaultConstructor( const Declaration * decl ) {
     638        FunctionDecl * isDefaultConstructor( Declaration * decl ) {
    1045639                if ( isConstructor( decl->name ) ) {
    1046                         if ( const FunctionDecl * func = dynamic_cast< const FunctionDecl * >( decl ) ) {
     640                        if ( FunctionDecl * func = dynamic_cast< FunctionDecl * >( decl ) ) {
    1047641                                if ( func->type->parameters.size() == 1 ) {
    1048642                                        return func;
     
    1052646                return nullptr;
    1053647        }
    1054         const FunctionDecl * isCopyConstructor( const Declaration * decl ) {
     648        FunctionDecl * isCopyConstructor( Declaration * decl ) {
    1055649                return isCopyFunction( decl, "?{}" );
    1056650        }
Note: See TracChangeset for help on using the changeset viewer.