Ignore:
Timestamp:
Oct 29, 2019, 4:01:24 PM (6 years ago)
Author:
Thierry Delisle <tdelisle@…>
Branches:
ADT, arm-eh, ast-experimental, enum, forall-pointer-decay, jacob/cs343-translation, jenkins-sandbox, master, new-ast, new-ast-unique-expr, pthread-emulation, qualifiedEnum
Children:
773db65, 9421f3d8
Parents:
7951100 (diff), 8364209 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'master' of plg.uwaterloo.ca:software/cfa/cfa-cc

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/InitTweak/InitTweak.cc

    r7951100 rb067d9b  
     1//
     2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
     3//
     4// The contents of this file are covered under the licence agreement in the
     5// file "LICENCE" distributed with Cforall.
     6//
     7// InitTweak.cc --
     8//
     9// Author           : Rob Schluntz
     10// Created On       : Fri May 13 11:26:36 2016
     11// Last Modified By : Peter A. Buhr
     12// Last Modified On : Thu Jul 25 22:21:48 2019
     13// Update Count     : 7
     14//
     15
    116#include <algorithm>               // for find, all_of
    217#include <cassert>                 // for assertf, assert, strict_dynamic_cast
     
    419#include <iterator>                // for back_insert_iterator, back_inserter
    520#include <memory>                  // for __shared_ptr
    6 
     21#include <vector>
     22
     23#include "AST/Expr.hpp"
     24#include "AST/Init.hpp"
     25#include "AST/Node.hpp"
     26#include "AST/Pass.hpp"
     27#include "AST/Stmt.hpp"
     28#include "AST/Type.hpp"
    729#include "Common/PassVisitor.h"
    830#include "Common/SemanticError.h"  // for SemanticError
     
    2648#include "Tuples/Tuples.h"         // for Tuples::isTtype
    2749
    28 class UntypedValofExpr;
    29 
    3050namespace InitTweak {
    3151        namespace {
     
    6787                };
    6888
    69                 struct InitFlattener : public WithShortCircuiting {
     89                struct InitFlattener_old : public WithShortCircuiting {
    7090                        void previsit( SingleInit * singleInit ) {
    7191                                visit_children = false;
     
    7595                };
    7696
    77         }
     97                struct InitFlattener_new : public ast::WithShortCircuiting {
     98                        std::vector< ast::ptr< ast::Expr > > argList;
     99
     100                        void previsit( const ast::SingleInit * singleInit ) {
     101                                visit_children = false;
     102                                argList.emplace_back( singleInit->value );
     103                        }
     104                };
     105
     106        } // anonymous namespace
    78107
    79108        std::list< Expression * > makeInitList( Initializer * init ) {
    80                 PassVisitor<InitFlattener> flattener;
     109                PassVisitor<InitFlattener_old> flattener;
    81110                maybeAccept( init, flattener );
    82111                return flattener.pass.argList;
     
    95124        }
    96125
    97         class InitExpander::ExpanderImpl {
     126std::vector< ast::ptr< ast::Expr > > makeInitList( const ast::Init * init ) {
     127        ast::Pass< InitFlattener_new > flattener;
     128        maybe_accept( init, flattener );
     129        return std::move( flattener.pass.argList );
     130}
     131
     132        class InitExpander_old::ExpanderImpl {
    98133        public:
    99134                virtual ~ExpanderImpl() = default;
     
    102137        };
    103138
    104         class InitImpl : public InitExpander::ExpanderImpl {
     139        class InitImpl_old : public InitExpander_old::ExpanderImpl {
    105140        public:
    106                 InitImpl( Initializer * init ) : init( init ) {}
    107                 virtual ~InitImpl() = default;
     141                InitImpl_old( Initializer * init ) : init( init ) {}
     142                virtual ~InitImpl_old() = default;
    108143
    109144                virtual std::list< Expression * > next( __attribute((unused)) std::list< Expression * > & indices ) {
     
    119154        };
    120155
    121         class ExprImpl : public InitExpander::ExpanderImpl {
     156        class ExprImpl_old : public InitExpander_old::ExpanderImpl {
    122157        public:
    123                 ExprImpl( Expression * expr ) : arg( expr ) {}
    124                 virtual ~ExprImpl() { delete arg; }
     158                ExprImpl_old( Expression * expr ) : arg( expr ) {}
     159                virtual ~ExprImpl_old() { delete arg; }
    125160
    126161                virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
     
    146181        };
    147182
    148         InitExpander::InitExpander( Initializer * init ) : expander( new InitImpl( init ) ) {}
    149 
    150         InitExpander::InitExpander( Expression * expr ) : expander( new ExprImpl( expr ) ) {}
    151 
    152         std::list< Expression * > InitExpander::operator*() {
     183        InitExpander_old::InitExpander_old( Initializer * init ) : expander( new InitImpl_old( init ) ) {}
     184
     185        InitExpander_old::InitExpander_old( Expression * expr ) : expander( new ExprImpl_old( expr ) ) {}
     186
     187        std::list< Expression * > InitExpander_old::operator*() {
    153188                return cur;
    154189        }
    155190
    156         InitExpander & InitExpander::operator++() {
     191        InitExpander_old & InitExpander_old::operator++() {
    157192                cur = expander->next( indices );
    158193                return *this;
     
    160195
    161196        // use array indices list to build switch statement
    162         void InitExpander::addArrayIndex( Expression * index, Expression * dimension ) {
     197        void InitExpander_old::addArrayIndex( Expression * index, Expression * dimension ) {
    163198                indices.push_back( index );
    164199                indices.push_back( dimension );
    165200        }
    166201
    167         void InitExpander::clearArrayIndices() {
     202        void InitExpander_old::clearArrayIndices() {
    168203                deleteAll( indices );
    169204                indices.clear();
    170205        }
    171206
    172         bool InitExpander::addReference() {
     207        bool InitExpander_old::addReference() {
    173208                bool added = false;
    174209                for ( Expression *& expr : cur ) {
     
    201236
    202237                template< typename OutIterator >
    203                 void build( UntypedExpr * callExpr, InitExpander::IndexList::iterator idx, InitExpander::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
     238                void build( UntypedExpr * callExpr, InitExpander_old::IndexList::iterator idx, InitExpander_old::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
    204239                        if ( idx == idxEnd ) return;
    205240                        Expression * index = *idx++;
     
    258293        // remaining elements.
    259294        // To accomplish this, generate switch statement, consuming all of expander's elements
    260         Statement * InitImpl::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
     295        Statement * InitImpl_old::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
    261296                if ( ! init ) return nullptr;
    262297                CompoundStmt * block = new CompoundStmt();
     
    271306        }
    272307
    273         Statement * ExprImpl::buildListInit( UntypedExpr *, std::list< Expression * > & ) {
     308        Statement * ExprImpl_old::buildListInit( UntypedExpr *, std::list< Expression * > & ) {
    274309                return nullptr;
    275310        }
    276311
    277         Statement * InitExpander::buildListInit( UntypedExpr * dst ) {
     312        Statement * InitExpander_old::buildListInit( UntypedExpr * dst ) {
    278313                return expander->buildListInit( dst, indices );
    279314        }
     315
     316class InitExpander_new::ExpanderImpl {
     317public:
     318        virtual ~ExpanderImpl() = default;
     319        virtual std::vector< ast::ptr< ast::Expr > > next( IndexList & indices ) = 0;
     320        virtual ast::ptr< ast::Stmt > buildListInit(
     321                ast::UntypedExpr * callExpr, IndexList & indices ) = 0;
     322};
     323
     324namespace {
     325        template< typename Out >
     326        void buildCallExpr(
     327                ast::UntypedExpr * callExpr, const ast::Expr * index, const ast::Expr * dimension,
     328                const ast::Init * init, Out & out
     329        ) {
     330                const CodeLocation & loc = init->location;
     331
     332                auto cond = new ast::UntypedExpr{
     333                        loc, new ast::NameExpr{ loc, "?<?" }, { index, dimension } };
     334
     335                std::vector< ast::ptr< ast::Expr > > args = makeInitList( init );
     336                splice( callExpr->args, args );
     337
     338                out.emplace_back( new ast::IfStmt{ loc, cond, new ast::ExprStmt{ loc, callExpr } } );
     339
     340                out.emplace_back( new ast::ExprStmt{
     341                        loc, new ast::UntypedExpr{ loc, new ast::NameExpr{ loc, "++?" }, { index } } } );
     342        }
     343
     344        template< typename Out >
     345        void build(
     346                ast::UntypedExpr * callExpr, const InitExpander_new::IndexList & indices,
     347                const ast::Init * init, Out & out
     348        ) {
     349                if ( indices.empty() ) return;
     350
     351                unsigned idx = 0;
     352
     353                const ast::Expr * index = indices[idx++];
     354                assert( idx != indices.size() );
     355                const ast::Expr * dimension = indices[idx++];
     356
     357                if ( idx == indices.size() ) {
     358                        if ( auto listInit = dynamic_cast< const ast::ListInit * >( init ) ) {
     359                                for ( const ast::Init * init : *listInit ) {
     360                                        buildCallExpr( callExpr, index, dimension, init, out );
     361                                }
     362                        } else {
     363                                buildCallExpr( callExpr, index, dimension, init, out );
     364                        }
     365                } else {
     366                        const CodeLocation & loc = init->location;
     367
     368                        unsigned long cond = 0;
     369                        auto listInit = dynamic_cast< const ast::ListInit * >( init );
     370                        if ( ! listInit ) { SemanticError( loc, "unbalanced list initializers" ); }
     371
     372                        static UniqueName targetLabel( "L__autogen__" );
     373                        ast::Label switchLabel{
     374                                loc, targetLabel.newName(), { new ast::Attribute{ "unused" } } };
     375
     376                        std::vector< ast::ptr< ast::Stmt > > branches;
     377                        for ( const ast::Init * init : *listInit ) {
     378                                auto condition = ast::ConstantExpr::from_ulong( loc, cond );
     379                                ++cond;
     380
     381                                std::vector< ast::ptr< ast::Stmt > > stmts;
     382                                build( callExpr, indices, init, stmts );
     383                                stmts.emplace_back(
     384                                        new ast::BranchStmt{ loc, ast::BranchStmt::Break, switchLabel } );
     385                                branches.emplace_back( new ast::CaseStmt{ loc, condition, std::move( stmts ) } );
     386                        }
     387                        out.emplace_back( new ast::SwitchStmt{ loc, index, std::move( branches ) } );
     388                        out.emplace_back( new ast::NullStmt{ loc, { switchLabel } } );
     389                }
     390        }
     391
     392        class InitImpl_new final : public InitExpander_new::ExpanderImpl {
     393                ast::ptr< ast::Init > init;
     394        public:
     395                InitImpl_new( const ast::Init * i ) : init( i ) {}
     396
     397                std::vector< ast::ptr< ast::Expr > > next( InitExpander_new::IndexList & ) override {
     398                        return makeInitList( init );
     399                }
     400
     401                ast::ptr< ast::Stmt > buildListInit(
     402                        ast::UntypedExpr * callExpr, InitExpander_new::IndexList & indices
     403                ) override {
     404                        // If array came with an initializer list, initialize each element. We may have more
     405                        // initializers than elements of the array; need to check at each index that we have
     406                        // not exceeded size. We may have fewer initializers than elements in the array; need
     407                        // to default-construct remaining elements. To accomplish this, generate switch
     408                        // statement consuming all of expander's elements
     409
     410                        if ( ! init ) return {};
     411
     412                        std::list< ast::ptr< ast::Stmt > > stmts;
     413                        build( callExpr, indices, init, stmts );
     414                        if ( stmts.empty() ) {
     415                                return {};
     416                        } else {
     417                                auto block = new ast::CompoundStmt{ init->location, std::move( stmts ) };
     418                                init = nullptr;  // consumed in creating the list init
     419                                return block;
     420                        }
     421                }
     422        };
     423
     424        class ExprImpl_new final : public InitExpander_new::ExpanderImpl {
     425                ast::ptr< ast::Expr > arg;
     426        public:
     427                ExprImpl_new( const ast::Expr * a ) : arg( a ) {}
     428
     429                std::vector< ast::ptr< ast::Expr > > next(
     430                        InitExpander_new::IndexList & indices
     431                ) override {
     432                        if ( ! arg ) return {};
     433
     434                        const CodeLocation & loc = arg->location;
     435                        const ast::Expr * expr = arg;
     436                        for ( auto it = indices.rbegin(); it != indices.rend(); ++it ) {
     437                                // go through indices and layer on subscript exprs ?[?]
     438                                ++it;
     439                                expr = new ast::UntypedExpr{
     440                                        loc, new ast::NameExpr{ loc, "?[?]" }, { expr, *it } };
     441                        }
     442                        return { expr };
     443                }
     444
     445                ast::ptr< ast::Stmt > buildListInit(
     446                        ast::UntypedExpr *, InitExpander_new::IndexList &
     447                ) override {
     448                        return {};
     449                }
     450        };
     451} // anonymous namespace
     452
     453InitExpander_new::InitExpander_new( const ast::Init * init )
     454: expander( new InitImpl_new{ init } ), crnt(), indices() {}
     455
     456InitExpander_new::InitExpander_new( const ast::Expr * expr )
     457: expander( new ExprImpl_new{ expr } ), crnt(), indices() {}
     458
     459std::vector< ast::ptr< ast::Expr > > InitExpander_new::operator* () { return crnt; }
     460
     461InitExpander_new & InitExpander_new::operator++ () {
     462        crnt = expander->next( indices );
     463        return *this;
     464}
     465
     466/// builds statement which has the same semantics as a C-style list initializer (for array
     467/// initializers) using callExpr as the base expression to perform initialization
     468ast::ptr< ast::Stmt > InitExpander_new::buildListInit( ast::UntypedExpr * callExpr ) {
     469        return expander->buildListInit( callExpr, indices );
     470}
     471
     472void InitExpander_new::addArrayIndex( const ast::Expr * index, const ast::Expr * dimension ) {
     473        indices.emplace_back( index );
     474        indices.emplace_back( dimension );
     475}
     476
     477void InitExpander_new::clearArrayIndices() { indices.clear(); }
     478
     479bool InitExpander_new::addReference() {
     480        for ( ast::ptr< ast::Expr > & expr : crnt ) {
     481                expr = new ast::AddressExpr{ expr };
     482        }
     483        return ! crnt.empty();
     484}
    280485
    281486        Type * getTypeofThis( FunctionType * ftype ) {
     
    306511        }
    307512
    308         struct CallFinder {
    309                 CallFinder( const std::list< std::string > & names ) : names( names ) {}
     513        struct CallFinder_old {
     514                CallFinder_old( const std::list< std::string > & names ) : names( names ) {}
    310515
    311516                void postvisit( ApplicationExpr * appExpr ) {
     
    330535        };
    331536
     537        struct CallFinder_new final {
     538                std::vector< ast::ptr< ast::Expr > > matches;
     539                const std::vector< std::string > names;
     540
     541                CallFinder_new( std::vector< std::string > && ns ) : matches(), names( std::move(ns) ) {}
     542
     543                void handleCallExpr( const ast::Expr * expr ) {
     544                        std::string fname = getFunctionName( expr );
     545                        if ( std::find( names.begin(), names.end(), fname ) != names.end() ) {
     546                                matches.emplace_back( expr );
     547                        }
     548                }
     549
     550                void postvisit( const ast::ApplicationExpr * expr ) { handleCallExpr( expr ); }
     551                void postvisit( const ast::UntypedExpr *     expr ) { handleCallExpr( expr ); }
     552        };
     553
    332554        void collectCtorDtorCalls( Statement * stmt, std::list< Expression * > & matches ) {
    333                 static PassVisitor<CallFinder> finder( std::list< std::string >{ "?{}", "^?{}" } );
     555                static PassVisitor<CallFinder_old> finder( std::list< std::string >{ "?{}", "^?{}" } );
    334556                finder.pass.matches = &matches;
    335557                maybeAccept( stmt, finder );
     558        }
     559
     560        std::vector< ast::ptr< ast::Expr > > collectCtorDtorCalls( const ast::Stmt * stmt ) {
     561                ast::Pass< CallFinder_new > finder{ std::vector< std::string >{ "?{}", "^?{}" } };
     562                maybe_accept( stmt, finder );
     563                return std::move( finder.pass.matches );
    336564        }
    337565
     
    339567                std::list< Expression * > matches;
    340568                collectCtorDtorCalls( stmt, matches );
    341                 assert( matches.size() <= 1 );
     569                assertf( matches.size() <= 1, "%zd constructor/destructors found in %s", matches.size(), toString( stmt ).c_str() );
    342570                return matches.size() == 1 ? matches.front() : nullptr;
    343571        }
     
    345573        namespace {
    346574                DeclarationWithType * getCalledFunction( Expression * expr );
     575                const ast::DeclWithType * getCalledFunction( const ast::Expr * expr );
    347576
    348577                template<typename CallExpr>
     
    354583                        return getCalledFunction( expr->get_args().front() );
    355584                }
     585
     586                template<typename CallExpr>
     587                const ast::DeclWithType * handleDerefCalledFunction( const CallExpr * expr ) {
     588                        // (*f)(x) => should get "f"
     589                        std::string name = getFunctionName( expr );
     590                        assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
     591                        assertf( ! expr->args.empty(), "Cannot get called function from dereference with no arguments" );
     592                        return getCalledFunction( expr->args.front() );
     593                }
     594
    356595
    357596                DeclarationWithType * getCalledFunction( Expression * expr ) {
     
    374613                        return nullptr;
    375614                }
     615
     616                const ast::DeclWithType * getCalledFunction( const ast::Expr * expr ) {
     617                        assert( expr );
     618                        if ( const ast::VariableExpr * varExpr = dynamic_cast< const ast::VariableExpr * >( expr ) ) {
     619                                return varExpr->var;
     620                        } else if ( const ast::MemberExpr * memberExpr = dynamic_cast< const ast::MemberExpr * >( expr ) ) {
     621                                return memberExpr->member;
     622                        } else if ( const ast::CastExpr * castExpr = dynamic_cast< const ast::CastExpr * >( expr ) ) {
     623                                return getCalledFunction( castExpr->arg );
     624                        } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * >( expr ) ) {
     625                                return handleDerefCalledFunction( untypedExpr );
     626                        } else if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * > ( expr ) ) {
     627                                return handleDerefCalledFunction( appExpr );
     628                        } else if ( const ast::AddressExpr * addrExpr = dynamic_cast< const ast::AddressExpr * >( expr ) ) {
     629                                return getCalledFunction( addrExpr->arg );
     630                        } else if ( const ast::CommaExpr * commaExpr = dynamic_cast< const ast::CommaExpr * >( expr ) ) {
     631                                return getCalledFunction( commaExpr->arg2 );
     632                        }
     633                        return nullptr;
     634                }
     635
     636                DeclarationWithType * getFunctionCore( const Expression * expr ) {
     637                        if ( const auto * appExpr = dynamic_cast< const ApplicationExpr * >( expr ) ) {
     638                                return getCalledFunction( appExpr->function );
     639                        } else if ( const auto * untyped = dynamic_cast< const UntypedExpr * >( expr ) ) {
     640                                return getCalledFunction( untyped->function );
     641                        }
     642                        assertf( false, "getFunction with unknown expression: %s", toString( expr ).c_str() );
     643                }
    376644        }
    377645
    378646        DeclarationWithType * getFunction( Expression * expr ) {
    379                 if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr ) ) {
    380                         return getCalledFunction( appExpr->get_function() );
    381                 } else if ( UntypedExpr * untyped = dynamic_cast< UntypedExpr * > ( expr ) ) {
    382                         return getCalledFunction( untyped->get_function() );
     647                return getFunctionCore( expr );
     648        }
     649
     650        const DeclarationWithType * getFunction( const Expression * expr ) {
     651                return getFunctionCore( expr );
     652        }
     653
     654        const ast::DeclWithType * getFunction( const ast::Expr * expr ) {
     655                if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr ) ) {
     656                        return getCalledFunction( appExpr->func );
     657                } else if ( const ast::UntypedExpr * untyped = dynamic_cast< const ast::UntypedExpr * > ( expr ) ) {
     658                        return getCalledFunction( untyped->func );
    383659                }
    384660                assertf( false, "getFunction received unknown expression: %s", toString( expr ).c_str() );
     
    395671        }
    396672
     673        const ast::ApplicationExpr * isIntrinsicCallExpr( const ast::Expr * expr ) {
     674                auto appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr );
     675                if ( ! appExpr ) return nullptr;
     676
     677                const ast::DeclWithType * func = getCalledFunction( appExpr->func );
     678                assertf( func,
     679                        "getCalledFunction returned nullptr: %s", toString( appExpr->func ).c_str() );
     680
     681                // check for Intrinsic only -- don't want to remove all overridable ctor/dtor because
     682                // autogenerated ctor/dtor will call all member dtors, and some members may have a
     683                // user-defined dtor
     684                return func->linkage == ast::Linkage::Intrinsic ? appExpr : nullptr;
     685        }
     686
    397687        namespace {
    398688                template <typename Predicate>
     
    403693                        return std::all_of( callExprs.begin(), callExprs.end(), pred);
    404694                }
     695
     696                template <typename Predicate>
     697                bool allofCtorDtor( const ast::Stmt * stmt, const Predicate & pred ) {
     698                        std::vector< ast::ptr< ast::Expr > > callExprs = collectCtorDtorCalls( stmt );
     699                        return std::all_of( callExprs.begin(), callExprs.end(), pred );
     700                }
    405701        }
    406702
     
    408704                return allofCtorDtor( stmt, []( Expression * callExpr ){
    409705                        if ( ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
    410                                 FunctionType *funcType = GenPoly::getFunctionType( appExpr->get_function()->get_result() );
     706                                FunctionType *funcType = GenPoly::getFunctionType( appExpr->function->result );
    411707                                assert( funcType );
    412708                                return funcType->get_parameters().size() == 1;
     709                        }
     710                        return false;
     711                });
     712        }
     713
     714        bool isIntrinsicSingleArgCallStmt( const ast::Stmt * stmt ) {
     715                return allofCtorDtor( stmt, []( const ast::Expr * callExpr ){
     716                        if ( const ast::ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
     717                                const ast::FunctionType * funcType =
     718                                        GenPoly::getFunctionType( appExpr->func->result );
     719                                assert( funcType );
     720                                return funcType->params.size() == 1;
    413721                        }
    414722                        return false;
     
    429737                                if ( pos == 0 ) return arg;
    430738                                pos--;
     739                        }
     740                        assert( false );
     741                }
     742
     743                template<typename CallExpr>
     744                const ast::Expr * callArg( const CallExpr * call, unsigned int pos ) {
     745                        if( pos >= call->args.size() ) {
     746                                assertf( false, "getCallArg for argument that doesn't exist: (%u); %s.",
     747                                        pos, toString( call ).c_str() );
     748                        }
     749                        for ( const ast::Expr * arg : call->args ) {
     750                                if ( pos == 0 ) return arg;
     751                                --pos;
    431752                        }
    432753                        assert( false );
     
    453774        }
    454775
     776        const ast::Expr * getCallArg( const ast::Expr * call, unsigned pos ) {
     777                if ( auto app = dynamic_cast< const ast::ApplicationExpr * >( call ) ) {
     778                        return callArg( app, pos );
     779                } else if ( auto untyped = dynamic_cast< const ast::UntypedExpr * >( call ) ) {
     780                        return callArg( untyped, pos );
     781                } else if ( auto tupleAssn = dynamic_cast< const ast::TupleAssignExpr * >( call ) ) {
     782                        const std::list<ast::ptr<ast::Stmt>>& stmts = tupleAssn->stmtExpr->stmts->kids;
     783                        assertf( ! stmts.empty(), "TupleAssignExpr missing statements." );
     784                        auto stmt  = strict_dynamic_cast< const ast::ExprStmt * >( stmts.back().get() );
     785                        auto tuple = strict_dynamic_cast< const ast::TupleExpr * >( stmt->expr.get() );
     786                        assertf( ! tuple->exprs.empty(), "TupleAssignExpr has empty tuple expr.");
     787                        return getCallArg( tuple->exprs.front(), pos );
     788                } else if ( auto ctor = dynamic_cast< const ast::ImplicitCopyCtorExpr * >( call ) ) {
     789                        return getCallArg( ctor->callExpr, pos );
     790                } else {
     791                        assertf( false, "Unexpected expression type passed to getCallArg: %s",
     792                                toString( call ).c_str() );
     793                }
     794        }
     795
    455796        namespace {
    456797                std::string funcName( Expression * func );
     798                std::string funcName( const ast::Expr * func );
    457799
    458800                template<typename CallExpr>
     
    463805                        assertf( ! expr->get_args().empty(), "Cannot get function name from dereference with no arguments" );
    464806                        return funcName( expr->get_args().front() );
     807                }
     808
     809                template<typename CallExpr>
     810                std::string handleDerefName( const CallExpr * expr ) {
     811                        // (*f)(x) => should get name "f"
     812                        std::string name = getFunctionName( expr );
     813                        assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
     814                        assertf( ! expr->args.empty(), "Cannot get function name from dereference with no arguments" );
     815                        return funcName( expr->args.front() );
    465816                }
    466817
     
    486837                        }
    487838                }
     839
     840                std::string funcName( const ast::Expr * func ) {
     841                        if ( const ast::NameExpr * nameExpr = dynamic_cast< const ast::NameExpr * >( func ) ) {
     842                                return nameExpr->name;
     843                        } else if ( const ast::VariableExpr * varExpr = dynamic_cast< const ast::VariableExpr * >( func ) ) {
     844                                return varExpr->var->name;
     845                        }       else if ( const ast::CastExpr * castExpr = dynamic_cast< const ast::CastExpr * >( func ) ) {
     846                                return funcName( castExpr->arg );
     847                        } else if ( const ast::MemberExpr * memberExpr = dynamic_cast< const ast::MemberExpr * >( func ) ) {
     848                                return memberExpr->member->name;
     849                        } else if ( const ast::UntypedMemberExpr * memberExpr = dynamic_cast< const ast::UntypedMemberExpr * > ( func ) ) {
     850                                return funcName( memberExpr->member );
     851                        } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * >( func ) ) {
     852                                return handleDerefName( untypedExpr );
     853                        } else if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( func ) ) {
     854                                return handleDerefName( appExpr );
     855                        } else if ( const ast::ConstructorExpr * ctorExpr = dynamic_cast< const ast::ConstructorExpr * >( func ) ) {
     856                                return funcName( getCallArg( ctorExpr->callExpr, 0 ) );
     857                        } else {
     858                                assertf( false, "Unexpected expression type being called as a function in call expression: %s", toString( func ).c_str() );
     859                        }
     860                }
    488861        }
    489862
     
    502875        }
    503876
     877        std::string getFunctionName( const ast::Expr * expr ) {
     878                // there's some unforunate overlap here with getCalledFunction. Ideally this would be able to use getCalledFunction and
     879                // return the name of the DeclarationWithType, but this needs to work for NameExpr and UntypedMemberExpr, where getCalledFunction
     880                // can't possibly do anything reasonable.
     881                if ( const ast::ApplicationExpr * appExpr = dynamic_cast< const ast::ApplicationExpr * >( expr ) ) {
     882                        return funcName( appExpr->func );
     883                } else if ( const ast::UntypedExpr * untypedExpr = dynamic_cast< const ast::UntypedExpr * > ( expr ) ) {
     884                        return funcName( untypedExpr->func );
     885                } else {
     886                        std::cerr << expr << std::endl;
     887                        assertf( false, "Unexpected expression type passed to getFunctionName" );
     888                }
     889        }
     890
    504891        Type * getPointerBase( Type * type ) {
    505892                if ( PointerType * ptrType = dynamic_cast< PointerType * >( type ) ) {
     
    513900                }
    514901        }
     902        const ast::Type* getPointerBase( const ast::Type* t ) {
     903                if ( const auto * p = dynamic_cast< const ast::PointerType * >( t ) ) {
     904                        return p->base;
     905                } else if ( const auto * a = dynamic_cast< const ast::ArrayType * >( t ) ) {
     906                        return a->base;
     907                } else if ( const auto * r = dynamic_cast< const ast::ReferenceType * >( t ) ) {
     908                        return r->base;
     909                } else return nullptr;
     910        }
    515911
    516912        Type * isPointerType( Type * type ) {
     
    561957                void previsit( OffsetofExpr * ) {}
    562958                void previsit( OffsetPackExpr * ) {}
    563                 void previsit( AttrExpr * ) {}
    564959                void previsit( CommaExpr * ) {}
    565960                void previsit( LogicalExpr * ) {}
     
    6091004        bool isCtorDtorAssign( const std::string & str ) { return isCtorDtor( str ) || isAssignment( str ); }
    6101005
    611         FunctionDecl * isCopyFunction( Declaration * decl, const std::string & fname ) {
    612                 FunctionDecl * function = dynamic_cast< FunctionDecl * >( decl );
     1006        const FunctionDecl * isCopyFunction( const Declaration * decl, const std::string & fname ) {
     1007                const FunctionDecl * function = dynamic_cast< const FunctionDecl * >( decl );
    6131008                if ( ! function ) return nullptr;
    6141009                if ( function->name != fname ) return nullptr;
     
    6271022        }
    6281023
    629         FunctionDecl * isAssignment( Declaration * decl ) {
     1024        bool isCopyFunction( const ast::FunctionDecl * decl ) {
     1025                const ast::FunctionType * ftype = decl->type;
     1026                if ( ftype->params.size() != 2 ) return false;
     1027
     1028                const ast::Type * t1 = getPointerBase( ftype->params.front()->get_type() );
     1029                if ( ! t1 ) return false;
     1030                const ast::Type * t2 = ftype->params.back()->get_type();
     1031
     1032                return ResolvExpr::typesCompatibleIgnoreQualifiers( t1, t2, ast::SymbolTable{} );
     1033        }
     1034
     1035        const FunctionDecl * isAssignment( const Declaration * decl ) {
    6301036                return isCopyFunction( decl, "?=?" );
    6311037        }
    632         FunctionDecl * isDestructor( Declaration * decl ) {
    633                 if ( isDestructor( decl->get_name() ) ) {
    634                         return dynamic_cast< FunctionDecl * >( decl );
     1038        const FunctionDecl * isDestructor( const Declaration * decl ) {
     1039                if ( isDestructor( decl->name ) ) {
     1040                        return dynamic_cast< const FunctionDecl * >( decl );
    6351041                }
    6361042                return nullptr;
    6371043        }
    638         FunctionDecl * isDefaultConstructor( Declaration * decl ) {
     1044        const FunctionDecl * isDefaultConstructor( const Declaration * decl ) {
    6391045                if ( isConstructor( decl->name ) ) {
    640                         if ( FunctionDecl * func = dynamic_cast< FunctionDecl * >( decl ) ) {
     1046                        if ( const FunctionDecl * func = dynamic_cast< const FunctionDecl * >( decl ) ) {
    6411047                                if ( func->type->parameters.size() == 1 ) {
    6421048                                        return func;
     
    6461052                return nullptr;
    6471053        }
    648         FunctionDecl * isCopyConstructor( Declaration * decl ) {
     1054        const FunctionDecl * isCopyConstructor( const Declaration * decl ) {
    6491055                return isCopyFunction( decl, "?{}" );
    6501056        }
Note: See TracChangeset for help on using the changeset viewer.