source: src/InitTweak/InitTweak.cc @ f898983

Last change on this file since f898983 was b7c53a9d, checked in by Andrew Beach <ajbeach@…>, 12 months ago

Added a new invariant check and the fixes required to make it pass. Not the new check is by no means exaustive (it doesn't even check every readonly pointer) but it should catch the most common/problematic cases.

  • Property mode set to 100644
File size: 40.0 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// InitTweak.cc --
8//
9// Author           : Rob Schluntz
10// Created On       : Fri May 13 11:26:36 2016
11// Last Modified By : Andrew Beach
12// Last Modified On : Wed Sep 22  9:50:00 2022
13// Update Count     : 21
14//
15
16#include <algorithm>               // for find, all_of
17#include <cassert>                 // for assertf, assert, strict_dynamic_cast
18#include <iostream>                // for ostream, cerr, endl
19#include <iterator>                // for back_insert_iterator, back_inserter
20#include <memory>                  // for __shared_ptr
21#include <vector>
22
23#include "AST/Expr.hpp"
24#include "AST/Init.hpp"
25#include "AST/Inspect.hpp"
26#include "AST/Node.hpp"
27#include "AST/Pass.hpp"
28#include "AST/Stmt.hpp"
29#include "AST/Type.hpp"
30#include "CodeGen/OperatorTable.h" // for isConstructor, isDestructor, isCto...
31#include "Common/PassVisitor.h"
32#include "Common/SemanticError.h"  // for SemanticError
33#include "Common/UniqueName.h"     // for UniqueName
34#include "Common/utility.h"        // for toString, deleteAll, maybeClone
35#include "GenPoly/GenPoly.h"       // for getFunctionType
36#include "InitTweak.h"
37#include "ResolvExpr/Unify.h"      // for typesCompatibleIgnoreQualifiers
38#include "SymTab/Autogen.h"
39#include "SymTab/Indexer.h"        // for Indexer
40#include "SynTree/LinkageSpec.h"   // for Spec, isBuiltin, Intrinsic
41#include "SynTree/Attribute.h"     // for Attribute
42#include "SynTree/Constant.h"      // for Constant
43#include "SynTree/Declaration.h"   // for ObjectDecl, DeclarationWithType
44#include "SynTree/Expression.h"    // for Expression, UntypedExpr, Applicati...
45#include "SynTree/Initializer.h"   // for Initializer, ListInit, Designation
46#include "SynTree/Label.h"         // for Label
47#include "SynTree/Statement.h"     // for CompoundStmt, ExprStmt, BranchStmt
48#include "SynTree/Type.h"          // for FunctionType, ArrayType, PointerType
49#include "SynTree/Visitor.h"       // for Visitor, maybeAccept
50#include "Tuples/Tuples.h"         // for Tuples::isTtype
51
52namespace InitTweak {
53        namespace {
54                struct HasDesignations : public WithShortCircuiting {
55                        bool hasDesignations = false;
56
57                        void previsit( BaseSyntaxNode * ) {
58                                // short circuit if we already know there are designations
59                                if ( hasDesignations ) visit_children = false;
60                        }
61
62                        void previsit( Designation * des ) {
63                                // short circuit if we already know there are designations
64                                if ( hasDesignations ) visit_children = false;
65                                else if ( ! des->get_designators().empty() ) {
66                                        hasDesignations = true;
67                                        visit_children = false;
68                                }
69                        }
70                };
71
72                struct InitDepthChecker : public WithGuards {
73                        bool depthOkay = true;
74                        Type * type;
75                        int curDepth = 0, maxDepth = 0;
76                        InitDepthChecker( Type * type ) : type( type ) {
77                                Type * t = type;
78                                while ( ArrayType * at = dynamic_cast< ArrayType * >( t ) ) {
79                                        maxDepth++;
80                                        t = at->get_base();
81                                }
82                                maxDepth++;
83                        }
84                        void previsit( ListInit * ) {
85                                curDepth++;
86                                GuardAction( [this]() { curDepth--; } );
87                                if ( curDepth > maxDepth ) depthOkay = false;
88                        }
89                };
90
91                struct HasDesignations_new : public ast::WithShortCircuiting {
92                        bool result = false;
93
94                        void previsit( const ast::Node * ) {
95                                // short circuit if we already know there are designations
96                                if ( result ) visit_children = false;
97                        }
98
99                        void previsit( const ast::Designation * des ) {
100                                // short circuit if we already know there are designations
101                                if ( result ) visit_children = false;
102                                else if ( ! des->designators.empty() ) {
103                                        result = true;
104                                        visit_children = false;
105                                }
106                        }
107                };
108
109                struct InitDepthChecker_new : public ast::WithGuards {
110                        bool result = true;
111                        const ast::Type * type;
112                        int curDepth = 0, maxDepth = 0;
113                        InitDepthChecker_new( const ast::Type * type ) : type( type ) {
114                                const ast::Type * t = type;
115                                while ( auto at = dynamic_cast< const ast::ArrayType * >( t ) ) {
116                                        maxDepth++;
117                                        t = at->base;
118                                }
119                                maxDepth++;
120                        }
121                        void previsit( ListInit * ) {
122                                curDepth++;
123                                GuardAction( [this]() { curDepth--; } );
124                                if ( curDepth > maxDepth ) result = false;
125                        }
126                };
127
128                struct InitFlattener_old : public WithShortCircuiting {
129                        void previsit( SingleInit * singleInit ) {
130                                visit_children = false;
131                                argList.push_back( singleInit->value->clone() );
132                        }
133                        std::list< Expression * > argList;
134                };
135
136                struct InitFlattener_new : public ast::WithShortCircuiting {
137                        std::vector< ast::ptr< ast::Expr > > argList;
138
139                        void previsit( const ast::SingleInit * singleInit ) {
140                                visit_children = false;
141                                argList.emplace_back( singleInit->value );
142                        }
143                };
144
145        } // anonymous namespace
146
147        std::list< Expression * > makeInitList( Initializer * init ) {
148                PassVisitor<InitFlattener_old> flattener;
149                maybeAccept( init, flattener );
150                return flattener.pass.argList;
151        }
152
153        bool isDesignated( Initializer * init ) {
154                PassVisitor<HasDesignations> finder;
155                maybeAccept( init, finder );
156                return finder.pass.hasDesignations;
157        }
158
159        bool checkInitDepth( ObjectDecl * objDecl ) {
160                PassVisitor<InitDepthChecker> checker( objDecl->type );
161                maybeAccept( objDecl->init, checker );
162                return checker.pass.depthOkay;
163        }
164
165        bool isDesignated( const ast::Init * init ) {
166                ast::Pass<HasDesignations_new> finder;
167                maybe_accept( init, finder );
168                return finder.core.result;
169        }
170
171        bool checkInitDepth( const ast::ObjectDecl * objDecl ) {
172                ast::Pass<InitDepthChecker_new> checker( objDecl->type );
173                maybe_accept( objDecl->init.get(), checker );
174                return checker.core.result;
175        }
176
177std::vector< ast::ptr< ast::Expr > > makeInitList( const ast::Init * init ) {
178        ast::Pass< InitFlattener_new > flattener;
179        maybe_accept( init, flattener );
180        return std::move( flattener.core.argList );
181}
182
183        class InitExpander_old::ExpanderImpl {
184        public:
185                virtual ~ExpanderImpl() = default;
186                virtual std::list< Expression * > next( std::list< Expression * > & indices ) = 0;
187                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices ) = 0;
188        };
189
190        class InitImpl_old : public InitExpander_old::ExpanderImpl {
191        public:
192                InitImpl_old( Initializer * init ) : init( init ) {}
193                virtual ~InitImpl_old() = default;
194
195                virtual std::list< Expression * > next( __attribute((unused)) std::list< Expression * > & indices ) {
196                        // this is wrong, but just a placeholder for now
197                        // if ( ! flattened ) flatten( indices );
198                        // return ! inits.empty() ? makeInitList( inits.front() ) : std::list< Expression * >();
199                        return makeInitList( init );
200                }
201
202                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices );
203        private:
204                Initializer * init;
205        };
206
207        class ExprImpl_old : public InitExpander_old::ExpanderImpl {
208        public:
209                ExprImpl_old( Expression * expr ) : arg( expr ) {}
210                virtual ~ExprImpl_old() { delete arg; }
211
212                virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
213                        std::list< Expression * > ret;
214                        Expression * expr = maybeClone( arg );
215                        if ( expr ) {
216                                for ( std::list< Expression * >::reverse_iterator it = indices.rbegin(); it != indices.rend(); ++it ) {
217                                        // go through indices and layer on subscript exprs ?[?]
218                                        ++it;
219                                        UntypedExpr * subscriptExpr = new UntypedExpr( new NameExpr( "?[?]") );
220                                        subscriptExpr->get_args().push_back( expr );
221                                        subscriptExpr->get_args().push_back( (*it)->clone() );
222                                        expr = subscriptExpr;
223                                }
224                                ret.push_back( expr );
225                        }
226                        return ret;
227                }
228
229                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices );
230        private:
231                Expression * arg;
232        };
233
234        InitExpander_old::InitExpander_old( Initializer * init ) : expander( new InitImpl_old( init ) ) {}
235
236        InitExpander_old::InitExpander_old( Expression * expr ) : expander( new ExprImpl_old( expr ) ) {}
237
238        std::list< Expression * > InitExpander_old::operator*() {
239                return cur;
240        }
241
242        InitExpander_old & InitExpander_old::operator++() {
243                cur = expander->next( indices );
244                return *this;
245        }
246
247        // use array indices list to build switch statement
248        void InitExpander_old::addArrayIndex( Expression * index, Expression * dimension ) {
249                indices.push_back( index );
250                indices.push_back( dimension );
251        }
252
253        void InitExpander_old::clearArrayIndices() {
254                deleteAll( indices );
255                indices.clear();
256        }
257
258        bool InitExpander_old::addReference() {
259                bool added = false;
260                for ( Expression *& expr : cur ) {
261                        expr = new AddressExpr( expr );
262                        added = true;
263                }
264                return added;
265        }
266
267        namespace {
268                /// given index i, dimension d, initializer init, and callExpr f, generates
269                ///   if (i < d) f(..., init)
270                ///   ++i;
271                /// so that only elements within the range of the array are constructed
272                template< typename OutIterator >
273                void buildCallExpr( UntypedExpr * callExpr, Expression * index, Expression * dimension, Initializer * init, OutIterator out ) {
274                        UntypedExpr * cond = new UntypedExpr( new NameExpr( "?<?") );
275                        cond->get_args().push_back( index->clone() );
276                        cond->get_args().push_back( dimension->clone() );
277
278                        std::list< Expression * > args = makeInitList( init );
279                        callExpr->get_args().splice( callExpr->get_args().end(), args );
280
281                        *out++ = new IfStmt( cond, new ExprStmt( callExpr ), nullptr );
282
283                        UntypedExpr * increment = new UntypedExpr( new NameExpr( "++?" ) );
284                        increment->get_args().push_back( index->clone() );
285                        *out++ = new ExprStmt( increment );
286                }
287
288                template< typename OutIterator >
289                void build( UntypedExpr * callExpr, InitExpander_old::IndexList::iterator idx, InitExpander_old::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
290                        if ( idx == idxEnd ) return;
291                        Expression * index = *idx++;
292                        assert( idx != idxEnd );
293                        Expression * dimension = *idx++;
294
295                        // xxx - may want to eventually issue a warning here if we can detect
296                        // that the number of elements exceeds to dimension of the array
297                        if ( idx == idxEnd ) {
298                                if ( ListInit * listInit = dynamic_cast< ListInit * >( init ) ) {
299                                        for ( Initializer * init : *listInit ) {
300                                                buildCallExpr( callExpr->clone(), index, dimension, init, out );
301                                        }
302                                } else {
303                                        buildCallExpr( callExpr->clone(), index, dimension, init, out );
304                                }
305                        } else {
306                                std::list< Statement * > branches;
307
308                                unsigned long cond = 0;
309                                ListInit * listInit = dynamic_cast< ListInit * >( init );
310                                if ( ! listInit ) {
311                                        // xxx - this shouldn't be an error, but need a way to
312                                        // terminate without creating output, so should catch this error
313                                        SemanticError( init->location, "unbalanced list initializers" );
314                                }
315
316                                static UniqueName targetLabel( "L__autogen__" );
317                                Label switchLabel( targetLabel.newName(), 0, std::list< Attribute * >{ new Attribute("unused") } );
318                                for ( Initializer * init : *listInit ) {
319                                        Expression * condition;
320                                        // check for designations
321                                        // if ( init-> ) {
322                                                condition = new ConstantExpr( Constant::from_ulong( cond ) );
323                                                ++cond;
324                                        // } else {
325                                        //      condition = // ... take designation
326                                        //      cond = // ... take designation+1
327                                        // }
328                                        std::list< Statement * > stmts;
329                                        build( callExpr, idx, idxEnd, init, back_inserter( stmts ) );
330                                        stmts.push_back( new BranchStmt( switchLabel, BranchStmt::Break ) );
331                                        CaseStmt * caseStmt = new CaseStmt( condition, stmts );
332                                        branches.push_back( caseStmt );
333                                }
334                                *out++ = new SwitchStmt( index->clone(), branches );
335                                *out++ = new NullStmt( { switchLabel } );
336                        }
337                }
338        }
339
340        // if array came with an initializer list: initialize each element
341        // may have more initializers than elements in the array - need to check at each index that
342        // we haven't exceeded size.
343        // may have fewer initializers than elements in the array - need to default construct
344        // remaining elements.
345        // To accomplish this, generate switch statement, consuming all of expander's elements
346        Statement * InitImpl_old::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
347                if ( ! init ) return nullptr;
348                CompoundStmt * block = new CompoundStmt();
349                build( dst, indices.begin(), indices.end(), init, back_inserter( block->get_kids() ) );
350                if ( block->get_kids().empty() ) {
351                        delete block;
352                        return nullptr;
353                } else {
354                        init = nullptr; // init was consumed in creating the list init
355                        return block;
356                }
357        }
358
359        Statement * ExprImpl_old::buildListInit( UntypedExpr *, std::list< Expression * > & ) {
360                return nullptr;
361        }
362
363        Statement * InitExpander_old::buildListInit( UntypedExpr * dst ) {
364                return expander->buildListInit( dst, indices );
365        }
366
367class InitExpander_new::ExpanderImpl {
368public:
369        virtual ~ExpanderImpl() = default;
370        virtual std::vector< ast::ptr< ast::Expr > > next( IndexList & indices ) = 0;
371        virtual ast::ptr< ast::Stmt > buildListInit(
372                ast::UntypedExpr * callExpr, IndexList & indices ) = 0;
373};
374
375namespace {
376        template< typename Out >
377        void buildCallExpr(
378                ast::UntypedExpr * callExpr, const ast::Expr * index, const ast::Expr * dimension,
379                const ast::Init * init, Out & out
380        ) {
381                const CodeLocation & loc = init->location;
382
383                auto cond = new ast::UntypedExpr{
384                        loc, new ast::NameExpr{ loc, "?<?" }, { index, dimension } };
385
386                std::vector< ast::ptr< ast::Expr > > args = makeInitList( init );
387                splice( callExpr->args, args );
388
389                out.emplace_back( new ast::IfStmt{ loc, cond, new ast::ExprStmt{ loc, callExpr } } );
390
391                out.emplace_back( new ast::ExprStmt{
392                        loc, new ast::UntypedExpr{ loc, new ast::NameExpr{ loc, "++?" }, { index } } } );
393        }
394
395        template< typename Out >
396        void build(
397                ast::UntypedExpr * callExpr, const InitExpander_new::IndexList & indices,
398                const ast::Init * init, Out & out
399        ) {
400                if ( indices.empty() ) return;
401
402                unsigned idx = 0;
403
404                const ast::Expr * index = indices[idx++];
405                assert( idx != indices.size() );
406                const ast::Expr * dimension = indices[idx++];
407
408                if ( idx == indices.size() ) {
409                        if ( auto listInit = dynamic_cast< const ast::ListInit * >( init ) ) {
410                                for ( const ast::Init * init : *listInit ) {
411                                        buildCallExpr( shallowCopy(callExpr), index, dimension, init, out );
412                                }
413                        } else {
414                                buildCallExpr( shallowCopy(callExpr), index, dimension, init, out );
415                        }
416                } else {
417                        const CodeLocation & loc = init->location;
418
419                        unsigned long cond = 0;
420                        auto listInit = dynamic_cast< const ast::ListInit * >( init );
421                        if ( ! listInit ) { SemanticError( loc, "unbalanced list initializers" ); }
422
423                        static UniqueName targetLabel( "L__autogen__" );
424                        ast::Label switchLabel{
425                                loc, targetLabel.newName(), { new ast::Attribute{ "unused" } } };
426
427                        std::vector< ast::ptr< ast::CaseClause > > branches;
428                        for ( const ast::Init * init : *listInit ) {
429                                auto condition = ast::ConstantExpr::from_ulong( loc, cond );
430                                ++cond;
431
432                                std::vector< ast::ptr< ast::Stmt > > stmts;
433                                build( callExpr, indices, init, stmts );
434                                stmts.emplace_back(
435                                        new ast::BranchStmt{ loc, ast::BranchStmt::Break, switchLabel } );
436                                branches.emplace_back( new ast::CaseClause{ loc, condition, std::move( stmts ) } );
437                        }
438                        out.emplace_back( new ast::SwitchStmt{ loc, index, std::move( branches ) } );
439                        out.emplace_back( new ast::NullStmt{ loc, { switchLabel } } );
440                }
441        }
442
443        class InitImpl_new final : public InitExpander_new::ExpanderImpl {
444                ast::ptr< ast::Init > init;
445        public:
446                InitImpl_new( const ast::Init * i ) : init( i ) {}
447
448                std::vector< ast::ptr< ast::Expr > > next( InitExpander_new::IndexList & ) override {
449                        return makeInitList( init );
450                }
451
452                ast::ptr< ast::Stmt > buildListInit(
453                        ast::UntypedExpr * callExpr, InitExpander_new::IndexList & indices
454                ) override {
455                        // If array came with an initializer list, initialize each element. We may have more
456                        // initializers than elements of the array; need to check at each index that we have
457                        // not exceeded size. We may have fewer initializers than elements in the array; need
458                        // to default-construct remaining elements. To accomplish this, generate switch
459                        // statement consuming all of expander's elements
460
461                        if ( ! init ) return {};
462
463                        std::list< ast::ptr< ast::Stmt > > stmts;
464                        build( callExpr, indices, init, stmts );
465                        if ( stmts.empty() ) {
466                                return {};
467                        } else {
468                                auto block = new ast::CompoundStmt{ init->location, std::move( stmts ) };
469                                init = nullptr;  // consumed in creating the list init
470                                return block;
471                        }
472                }
473        };
474
475        class ExprImpl_new final : public InitExpander_new::ExpanderImpl {
476                ast::ptr< ast::Expr > arg;
477        public:
478                ExprImpl_new( const ast::Expr * a ) : arg( a ) {}
479
480                std::vector< ast::ptr< ast::Expr > > next(
481                        InitExpander_new::IndexList & indices
482                ) override {
483                        if ( ! arg ) return {};
484
485                        const CodeLocation & loc = arg->location;
486                        const ast::Expr * expr = arg;
487                        for ( auto it = indices.rbegin(); it != indices.rend(); ++it ) {
488                                // go through indices and layer on subscript exprs ?[?]
489                                ++it;
490                                expr = new ast::UntypedExpr{
491                                        loc, new ast::NameExpr{ loc, "?[?]" }, { expr, *it } };
492                        }
493                        return { expr };
494                }
495
496                ast::ptr< ast::Stmt > buildListInit(
497                        ast::UntypedExpr *, InitExpander_new::IndexList &
498                ) override {
499                        return {};
500                }
501        };
502} // anonymous namespace
503
504InitExpander_new::InitExpander_new( const ast::Init * init )
505: expander( new InitImpl_new{ init } ), crnt(), indices() {}
506
507InitExpander_new::InitExpander_new( const ast::Expr * expr )
508: expander( new ExprImpl_new{ expr } ), crnt(), indices() {}
509
510std::vector< ast::ptr< ast::Expr > > InitExpander_new::operator* () { return crnt; }
511
512InitExpander_new & InitExpander_new::operator++ () {
513        crnt = expander->next( indices );
514        return *this;
515}
516
517/// builds statement which has the same semantics as a C-style list initializer (for array
518/// initializers) using callExpr as the base expression to perform initialization
519ast::ptr< ast::Stmt > InitExpander_new::buildListInit( ast::UntypedExpr * callExpr ) {
520        return expander->buildListInit( callExpr, indices );
521}
522
523void InitExpander_new::addArrayIndex( const ast::Expr * index, const ast::Expr * dimension ) {
524        indices.emplace_back( index );
525        indices.emplace_back( dimension );
526}
527
528void InitExpander_new::clearArrayIndices() { indices.clear(); }
529
530bool InitExpander_new::addReference() {
531        for ( ast::ptr< ast::Expr > & expr : crnt ) {
532                expr = new ast::AddressExpr{ expr };
533        }
534        return ! crnt.empty();
535}
536
537        Type * getTypeofThis( FunctionType * ftype ) {
538                assertf( ftype, "getTypeofThis: nullptr ftype" );
539                ObjectDecl * thisParam = getParamThis( ftype );
540                ReferenceType * refType = strict_dynamic_cast< ReferenceType * >( thisParam->type );
541                return refType->base;
542        }
543
544        const ast::Type * getTypeofThis( const ast::FunctionType * ftype ) {
545                assertf( ftype, "getTypeofThis: nullptr ftype" );
546                const std::vector<ast::ptr<ast::Type>> & params = ftype->params;
547                assertf( !params.empty(), "getTypeofThis: ftype with 0 parameters: %s",
548                                toString( ftype ).c_str() );
549                const ast::ReferenceType * refType =
550                        params.front().strict_as<ast::ReferenceType>();
551                return refType->base;
552        }
553
554        ObjectDecl * getParamThis( FunctionType * ftype ) {
555                assertf( ftype, "getParamThis: nullptr ftype" );
556                auto & params = ftype->parameters;
557                assertf( ! params.empty(), "getParamThis: ftype with 0 parameters: %s", toString( ftype ).c_str() );
558                return strict_dynamic_cast< ObjectDecl * >( params.front() );
559        }
560
561        const ast::ObjectDecl * getParamThis(const ast::FunctionDecl * func) {
562                assertf( func, "getParamThis: nullptr ftype" );
563                auto & params = func->params;
564                assertf( ! params.empty(), "getParamThis: ftype with 0 parameters: %s", toString( func ).c_str());
565                return params.front().strict_as<ast::ObjectDecl>();
566        }
567
568        bool tryConstruct( DeclarationWithType * dwt ) {
569                ObjectDecl * objDecl = dynamic_cast< ObjectDecl * >( dwt );
570                if ( ! objDecl ) return false;
571                return (objDecl->get_init() == nullptr ||
572                                ( objDecl->get_init() != nullptr && objDecl->get_init()->get_maybeConstructed() ))
573                        && ! objDecl->get_storageClasses().is_extern
574                        && isConstructable( objDecl->type );
575        }
576
577        bool isConstructable( Type * type ) {
578                return ! dynamic_cast< VarArgsType * >( type ) && ! dynamic_cast< ReferenceType * >( type ) && ! dynamic_cast< FunctionType * >( type ) && ! Tuples::isTtype( type );
579        }
580
581        bool tryConstruct( const ast::DeclWithType * dwt ) {
582                auto objDecl = dynamic_cast< const ast::ObjectDecl * >( dwt );
583                if ( ! objDecl ) return false;
584                return (objDecl->init == nullptr ||
585                                ( objDecl->init != nullptr && objDecl->init->maybeConstructed ))
586                        && ! objDecl->storage.is_extern
587                        && isConstructable( objDecl->type );
588        }
589
590        bool isConstructable( const ast::Type * type ) {
591                return ! dynamic_cast< const ast::VarArgsType * >( type ) && ! dynamic_cast< const ast::ReferenceType * >( type )
592                && ! dynamic_cast< const ast::FunctionType * >( type ) && ! Tuples::isTtype( type );
593        }
594
595        struct CallFinder_old {
596                CallFinder_old( const std::list< std::string > & names ) : names( names ) {}
597
598                void postvisit( ApplicationExpr * appExpr ) {
599                        handleCallExpr( appExpr );
600                }
601
602                void postvisit( UntypedExpr * untypedExpr ) {
603                        handleCallExpr( untypedExpr );
604                }
605
606                std::list< Expression * > * matches;
607        private:
608                const std::list< std::string > names;
609
610                template< typename CallExpr >
611                void handleCallExpr( CallExpr * expr ) {
612                        std::string fname = getFunctionName( expr );
613                        if ( std::find( names.begin(), names.end(), fname ) != names.end() ) {
614                                matches->push_back( expr );
615                        }
616                }
617        };
618
619        struct CallFinder_new final {
620                std::vector< const ast::Expr * > matches;
621                const std::vector< std::string > names;
622
623                CallFinder_new( std::vector< std::string > && ns ) : matches(), names( std::move(ns) ) {}
624
625                void handleCallExpr( const ast::Expr * expr ) {
626                        std::string fname = getFunctionName( expr );
627                        if ( std::find( names.begin(), names.end(), fname ) != names.end() ) {
628                                matches.emplace_back( expr );
629                        }
630                }
631
632                void postvisit( const ast::ApplicationExpr * expr ) { handleCallExpr( expr ); }
633                void postvisit( const ast::UntypedExpr *     expr ) { handleCallExpr( expr ); }
634        };
635
636        void collectCtorDtorCalls( Statement * stmt, std::list< Expression * > & matches ) {
637                static PassVisitor<CallFinder_old> finder( std::list< std::string >{ "?{}", "^?{}" } );
638                finder.pass.matches = &matches;
639                maybeAccept( stmt, finder );
640        }
641
642        std::vector< const ast::Expr * > collectCtorDtorCalls( const ast::Stmt * stmt ) {
643                ast::Pass< CallFinder_new > finder{ std::vector< std::string >{ "?{}", "^?{}" } };
644                maybe_accept( stmt, finder );
645                return std::move( finder.core.matches );
646        }
647
648        Expression * getCtorDtorCall( Statement * stmt ) {
649                std::list< Expression * > matches;
650                collectCtorDtorCalls( stmt, matches );
651                assertf( matches.size() <= 1, "%zd constructor/destructors found in %s", matches.size(), toString( stmt ).c_str() );
652                return matches.size() == 1 ? matches.front() : nullptr;
653        }
654
655        namespace {
656                DeclarationWithType * getCalledFunction( Expression * expr );
657
658                template<typename CallExpr>
659                DeclarationWithType * handleDerefCalledFunction( CallExpr * expr ) {
660                        // (*f)(x) => should get "f"
661                        std::string name = getFunctionName( expr );
662                        assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
663                        assertf( ! expr->get_args().empty(), "Cannot get called function from dereference with no arguments" );
664                        return getCalledFunction( expr->get_args().front() );
665                }
666
667                DeclarationWithType * getCalledFunction( Expression * expr ) {
668                        assert( expr );
669                        if ( VariableExpr * varExpr = dynamic_cast< VariableExpr * >( expr ) ) {
670                                return varExpr->var;
671                        } else if ( MemberExpr * memberExpr = dynamic_cast< MemberExpr * >( expr ) ) {
672                                return memberExpr->member;
673                        } else if ( CastExpr * castExpr = dynamic_cast< CastExpr * >( expr ) ) {
674                                return getCalledFunction( castExpr->arg );
675                        } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( expr ) ) {
676                                return handleDerefCalledFunction( untypedExpr );
677                        } else if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * > ( expr ) ) {
678                                return handleDerefCalledFunction( appExpr );
679                        } else if ( AddressExpr * addrExpr = dynamic_cast< AddressExpr * >( expr ) ) {
680                                return getCalledFunction( addrExpr->arg );
681                        } else if ( CommaExpr * commaExpr = dynamic_cast< CommaExpr * >( expr ) ) {
682                                return getCalledFunction( commaExpr->arg2 );
683                        }
684                        return nullptr;
685                }
686
687                DeclarationWithType * getFunctionCore( const Expression * expr ) {
688                        if ( const auto * appExpr = dynamic_cast< const ApplicationExpr * >( expr ) ) {
689                                return getCalledFunction( appExpr->function );
690                        } else if ( const auto * untyped = dynamic_cast< const UntypedExpr * >( expr ) ) {
691                                return getCalledFunction( untyped->function );
692                        }
693                        assertf( false, "getFunction with unknown expression: %s", toString( expr ).c_str() );
694                }
695        }
696
697        DeclarationWithType * getFunction( Expression * expr ) {
698                return getFunctionCore( expr );
699        }
700
701        const DeclarationWithType * getFunction( const Expression * expr ) {
702                return getFunctionCore( expr );
703        }
704
705        ApplicationExpr * isIntrinsicCallExpr( Expression * expr ) {
706                ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr );
707                if ( ! appExpr ) return nullptr;
708                DeclarationWithType * function = getCalledFunction( appExpr->get_function() );
709                assertf( function, "getCalledFunction returned nullptr: %s", toString( appExpr->get_function() ).c_str() );
710                // check for Intrinsic only - don't want to remove all overridable ctor/dtors because autogenerated ctor/dtor
711                // will call all member dtors, and some members may have a user defined dtor.
712                return function->get_linkage() == LinkageSpec::Intrinsic ? appExpr : nullptr;
713        }
714
715        namespace {
716                template <typename Predicate>
717                bool allofCtorDtor( Statement * stmt, const Predicate & pred ) {
718                        std::list< Expression * > callExprs;
719                        collectCtorDtorCalls( stmt, callExprs );
720                        return std::all_of( callExprs.begin(), callExprs.end(), pred);
721                }
722
723                template <typename Predicate>
724                bool allofCtorDtor( const ast::Stmt * stmt, const Predicate & pred ) {
725                        std::vector< const ast::Expr * > callExprs = collectCtorDtorCalls( stmt );
726                        return std::all_of( callExprs.begin(), callExprs.end(), pred );
727                }
728        }
729
730        bool isIntrinsicSingleArgCallStmt( Statement * stmt ) {
731                return allofCtorDtor( stmt, []( Expression * callExpr ){
732                        if ( ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
733                                FunctionType *funcType = GenPoly::getFunctionType( appExpr->function->result );
734                                assert( funcType );
735                                return funcType->get_parameters().size() == 1;
736                        }
737                        return false;
738                });
739        }
740
741        bool isIntrinsicSingleArgCallStmt( const ast::Stmt * stmt ) {
742                return allofCtorDtor( stmt, []( const ast::Expr * callExpr ){
743                        if ( const ast::ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
744                                const ast::FunctionType * funcType =
745                                        GenPoly::getFunctionType( appExpr->func->result );
746                                assert( funcType );
747                                return funcType->params.size() == 1;
748                        }
749                        return false;
750                });
751        }
752
753        bool isIntrinsicCallStmt( Statement * stmt ) {
754                return allofCtorDtor( stmt, []( Expression * callExpr ) {
755                        return isIntrinsicCallExpr( callExpr );
756                });
757        }
758
759        namespace {
760                template<typename CallExpr>
761                Expression *& callArg( CallExpr * callExpr, unsigned int pos ) {
762                        if ( pos >= callExpr->get_args().size() ) assertf( false, "getCallArg for argument that doesn't exist: (%u); %s.", pos, toString( callExpr ).c_str() );
763                        for ( Expression *& arg : callExpr->get_args() ) {
764                                if ( pos == 0 ) return arg;
765                                pos--;
766                        }
767                        assert( false );
768                }
769        }
770
771        Expression *& getCallArg( Expression * callExpr, unsigned int pos ) {
772                if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( callExpr ) ) {
773                        return callArg( appExpr, pos );
774                } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( callExpr ) ) {
775                        return callArg( untypedExpr, pos );
776                } else if ( TupleAssignExpr * tupleExpr = dynamic_cast< TupleAssignExpr * > ( callExpr ) ) {
777                        std::list< Statement * > & stmts = tupleExpr->get_stmtExpr()->get_statements()->get_kids();
778                        assertf( ! stmts.empty(), "TupleAssignExpr somehow has no statements." );
779                        ExprStmt * stmt = strict_dynamic_cast< ExprStmt * >( stmts.back() );
780                        TupleExpr * tuple = strict_dynamic_cast< TupleExpr * >( stmt->get_expr() );
781                        assertf( ! tuple->get_exprs().empty(), "TupleAssignExpr somehow has empty tuple expr." );
782                        return getCallArg( tuple->get_exprs().front(), pos );
783                } else if ( ImplicitCopyCtorExpr * copyCtor = dynamic_cast< ImplicitCopyCtorExpr * >( callExpr ) ) {
784                        return getCallArg( copyCtor->callExpr, pos );
785                } else {
786                        assertf( false, "Unexpected expression type passed to getCallArg: %s", toString( callExpr ).c_str() );
787                }
788        }
789
790        namespace {
791                std::string funcName( Expression * func );
792
793                template<typename CallExpr>
794                std::string handleDerefName( CallExpr * expr ) {
795                        // (*f)(x) => should get name "f"
796                        std::string name = getFunctionName( expr );
797                        assertf( name == "*?", "Unexpected untyped expression: %s", name.c_str() );
798                        assertf( ! expr->get_args().empty(), "Cannot get function name from dereference with no arguments" );
799                        return funcName( expr->get_args().front() );
800                }
801
802                std::string funcName( Expression * func ) {
803                        if ( NameExpr * nameExpr = dynamic_cast< NameExpr * >( func ) ) {
804                                return nameExpr->get_name();
805                        } else if ( VariableExpr * varExpr = dynamic_cast< VariableExpr * >( func ) ) {
806                                return varExpr->get_var()->get_name();
807                        } else if ( CastExpr * castExpr = dynamic_cast< CastExpr * >( func ) ) {
808                                return funcName( castExpr->get_arg() );
809                        } else if ( MemberExpr * memberExpr = dynamic_cast< MemberExpr * >( func ) ) {
810                                return memberExpr->get_member()->get_name();
811                        } else if ( UntypedMemberExpr * memberExpr = dynamic_cast< UntypedMemberExpr * > ( func ) ) {
812                                return funcName( memberExpr->get_member() );
813                        } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( func ) ) {
814                                return handleDerefName( untypedExpr );
815                        } else if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( func ) ) {
816                                return handleDerefName( appExpr );
817                        } else if ( ConstructorExpr * ctorExpr = dynamic_cast< ConstructorExpr * >( func ) ) {
818                                return funcName( getCallArg( ctorExpr->get_callExpr(), 0 ) );
819                        } else {
820                                assertf( false, "Unexpected expression type being called as a function in call expression: %s", toString( func ).c_str() );
821                        }
822                }
823        }
824
825        std::string getFunctionName( Expression * expr ) {
826                // there's some unforunate overlap here with getCalledFunction. Ideally this would be able to use getCalledFunction and
827                // return the name of the DeclarationWithType, but this needs to work for NameExpr and UntypedMemberExpr, where getCalledFunction
828                // can't possibly do anything reasonable.
829                if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr ) ) {
830                        return funcName( appExpr->get_function() );
831                } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * > ( expr ) ) {
832                        return funcName( untypedExpr->get_function() );
833                } else {
834                        std::cerr << expr << std::endl;
835                        assertf( false, "Unexpected expression type passed to getFunctionName" );
836                }
837        }
838
839        Type * getPointerBase( Type * type ) {
840                if ( PointerType * ptrType = dynamic_cast< PointerType * >( type ) ) {
841                        return ptrType->get_base();
842                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
843                        return arrayType->get_base();
844                } else if ( ReferenceType * refType = dynamic_cast< ReferenceType * >( type ) ) {
845                        return refType->get_base();
846                } else {
847                        return nullptr;
848                }
849        }
850
851        Type * isPointerType( Type * type ) {
852                return getPointerBase( type ) ? type : nullptr;
853        }
854
855        ApplicationExpr * createBitwiseAssignment( Expression * dst, Expression * src ) {
856                static FunctionDecl * assign = nullptr;
857                if ( ! assign ) {
858                        // temporary? Generate a fake assignment operator to represent bitwise assignments.
859                        // This operator could easily exist as a real function, but it's tricky because nothing should resolve to this function.
860                        TypeDecl * td = new TypeDecl( "T", noStorageClasses, nullptr, TypeDecl::Dtype, true );
861                        assign = new FunctionDecl( "?=?", noStorageClasses, LinkageSpec::Intrinsic, SymTab::genAssignType( new TypeInstType( noQualifiers, td->name, td ) ), nullptr );
862                }
863                if ( dynamic_cast< ReferenceType * >( dst->result ) ) {
864                        for (int depth = dst->result->referenceDepth(); depth > 0; depth--) {
865                                dst = new AddressExpr( dst );
866                        }
867                } else {
868                        dst = new CastExpr( dst, new ReferenceType( noQualifiers, dst->result->clone() ) );
869                }
870                if ( dynamic_cast< ReferenceType * >( src->result ) ) {
871                        for (int depth = src->result->referenceDepth(); depth > 0; depth--) {
872                                src = new AddressExpr( src );
873                        }
874                }
875                return new ApplicationExpr( VariableExpr::functionPointer( assign ), { dst, src } );
876        }
877
878        // looks like some other such codegen uses UntypedExpr and does not create fake function. should revisit afterwards
879        // following passes may accidentally resolve this expression if returned as untyped...
880        ast::Expr * createBitwiseAssignment (const ast::Expr * dst, const ast::Expr * src) {
881                static ast::ptr<ast::FunctionDecl> assign = nullptr;
882                if (!assign) {
883                        auto td = new ast::TypeDecl(CodeLocation(), "T", {}, nullptr, ast::TypeDecl::Dtype, true);
884                        assign = new ast::FunctionDecl(CodeLocation(), "?=?", {td},
885                        { new ast::ObjectDecl(CodeLocation(), "_dst", new ast::ReferenceType(new ast::TypeInstType("T", td))),
886                          new ast::ObjectDecl(CodeLocation(), "_src", new ast::TypeInstType("T", td))},
887                        { new ast::ObjectDecl(CodeLocation(), "_ret", new ast::TypeInstType("T", td))}, nullptr, {}, ast::Linkage::Intrinsic);
888                }
889                if (dst->result.as<ast::ReferenceType>()) {
890                        for (int depth = dst->result->referenceDepth(); depth > 0; depth--) {
891                                dst = new ast::AddressExpr(dst);
892                        }
893                }
894                else {
895                        dst = new ast::CastExpr(dst, new ast::ReferenceType(dst->result, {}));
896                }
897                if (src->result.as<ast::ReferenceType>()) {
898                        for (int depth = src->result->referenceDepth(); depth > 0; depth--) {
899                                src = new ast::AddressExpr(src);
900                        }
901                }
902                return new ast::ApplicationExpr(dst->location, ast::VariableExpr::functionPointer(dst->location, assign), {dst, src});
903        }
904
905        struct ConstExprChecker : public WithShortCircuiting {
906                // most expressions are not const expr
907                void previsit( Expression * ) { isConstExpr = false; visit_children = false; }
908
909                void previsit( AddressExpr *addressExpr ) {
910                        visit_children = false;
911
912                        // address of a variable or member expression is constexpr
913                        Expression * arg = addressExpr->get_arg();
914                        if ( ! dynamic_cast< NameExpr * >( arg) && ! dynamic_cast< VariableExpr * >( arg ) && ! dynamic_cast< MemberExpr * >( arg ) && ! dynamic_cast< UntypedMemberExpr * >( arg ) ) isConstExpr = false;
915                }
916
917                // these expressions may be const expr, depending on their children
918                void previsit( SizeofExpr * ) {}
919                void previsit( AlignofExpr * ) {}
920                void previsit( UntypedOffsetofExpr * ) {}
921                void previsit( OffsetofExpr * ) {}
922                void previsit( OffsetPackExpr * ) {}
923                void previsit( CommaExpr * ) {}
924                void previsit( LogicalExpr * ) {}
925                void previsit( ConditionalExpr * ) {}
926                void previsit( CastExpr * ) {}
927                void previsit( ConstantExpr * ) {}
928
929                void previsit( VariableExpr * varExpr ) {
930                        visit_children = false;
931
932                        if ( EnumInstType * inst = dynamic_cast< EnumInstType * >( varExpr->result ) ) {
933                                long long int value;
934                                if ( inst->baseEnum->valueOf( varExpr->var, value ) ) {
935                                        // enumerators are const expr
936                                        return;
937                                }
938                        }
939                        isConstExpr = false;
940                }
941
942                bool isConstExpr = true;
943        };
944
945        struct ConstExprChecker_new : public ast::WithShortCircuiting {
946                // most expressions are not const expr
947                void previsit( const ast::Expr * ) { result = false; visit_children = false; }
948
949                void previsit( const ast::AddressExpr *addressExpr ) {
950                        visit_children = false;
951                        const ast::Expr * arg = addressExpr->arg;
952
953                        // address of a variable or member expression is constexpr
954                        if ( ! dynamic_cast< const ast::NameExpr * >( arg )
955                        && ! dynamic_cast< const ast::VariableExpr * >( arg )
956                        && ! dynamic_cast< const ast::MemberExpr * >( arg )
957                        && ! dynamic_cast< const ast::UntypedMemberExpr * >( arg ) ) result = false;
958                }
959
960                // these expressions may be const expr, depending on their children
961                void previsit( const ast::SizeofExpr * ) {}
962                void previsit( const ast::AlignofExpr * ) {}
963                void previsit( const ast::UntypedOffsetofExpr * ) {}
964                void previsit( const ast::OffsetofExpr * ) {}
965                void previsit( const ast::OffsetPackExpr * ) {}
966                void previsit( const ast::CommaExpr * ) {}
967                void previsit( const ast::LogicalExpr * ) {}
968                void previsit( const ast::ConditionalExpr * ) {}
969                void previsit( const ast::CastExpr * ) {}
970                void previsit( const ast::ConstantExpr * ) {}
971
972                void previsit( const ast::VariableExpr * varExpr ) {
973                        visit_children = false;
974
975                        if ( auto inst = varExpr->result.as<ast::EnumInstType>() ) {
976                                long long int value;
977                                if ( inst->base->valueOf( varExpr->var, value ) ) {
978                                        // enumerators are const expr
979                                        return;
980                                }
981                        }
982                        result = false;
983                }
984
985                bool result = true;
986        };
987
988        bool isConstExpr( Expression * expr ) {
989                if ( expr ) {
990                        PassVisitor<ConstExprChecker> checker;
991                        expr->accept( checker );
992                        return checker.pass.isConstExpr;
993                }
994                return true;
995        }
996
997        bool isConstExpr( Initializer * init ) {
998                if ( init ) {
999                        PassVisitor<ConstExprChecker> checker;
1000                        init->accept( checker );
1001                        return checker.pass.isConstExpr;
1002                } // if
1003                // for all intents and purposes, no initializer means const expr
1004                return true;
1005        }
1006
1007        bool isConstExpr( const ast::Expr * expr ) {
1008                if ( expr ) {
1009                        ast::Pass<ConstExprChecker_new> checker;
1010                        expr->accept( checker );
1011                        return checker.core.result;
1012                }
1013                return true;
1014        }
1015
1016        bool isConstExpr( const ast::Init * init ) {
1017                if ( init ) {
1018                        ast::Pass<ConstExprChecker_new> checker;
1019                        init->accept( checker );
1020                        return checker.core.result;
1021                } // if
1022                // for all intents and purposes, no initializer means const expr
1023                return true;
1024        }
1025
1026        const FunctionDecl * isCopyFunction( const Declaration * decl, const std::string & fname ) {
1027                const FunctionDecl * function = dynamic_cast< const FunctionDecl * >( decl );
1028                if ( ! function ) return nullptr;
1029                if ( function->name != fname ) return nullptr;
1030                FunctionType * ftype = function->type;
1031                if ( ftype->parameters.size() != 2 ) return nullptr;
1032
1033                Type * t1 = getPointerBase( ftype->get_parameters().front()->get_type() );
1034                Type * t2 = ftype->parameters.back()->get_type();
1035                assert( t1 );
1036
1037                if ( ResolvExpr::typesCompatibleIgnoreQualifiers( t1, t2, SymTab::Indexer() ) ) {
1038                        return function;
1039                } else {
1040                        return nullptr;
1041                }
1042        }
1043
1044bool isAssignment( const ast::FunctionDecl * decl ) {
1045        return CodeGen::isAssignment( decl->name ) && isCopyFunction( decl );
1046}
1047
1048bool isDestructor( const ast::FunctionDecl * decl ) {
1049        return CodeGen::isDestructor( decl->name );
1050}
1051
1052bool isDefaultConstructor( const ast::FunctionDecl * decl ) {
1053        return CodeGen::isConstructor( decl->name ) && 1 == decl->params.size();
1054}
1055
1056bool isCopyConstructor( const ast::FunctionDecl * decl ) {
1057        return CodeGen::isConstructor( decl->name ) && 2 == decl->params.size();
1058}
1059
1060bool isCopyFunction( const ast::FunctionDecl * decl ) {
1061        const ast::FunctionType * ftype = decl->type;
1062        if ( ftype->params.size() != 2 ) return false;
1063
1064        const ast::Type * t1 = ast::getPointerBase( ftype->params.front() );
1065        if ( ! t1 ) return false;
1066        const ast::Type * t2 = ftype->params.back();
1067
1068        return ResolvExpr::typesCompatibleIgnoreQualifiers( t1, t2 );
1069}
1070
1071
1072        const FunctionDecl * isAssignment( const Declaration * decl ) {
1073                return isCopyFunction( decl, "?=?" );
1074        }
1075        const FunctionDecl * isDestructor( const Declaration * decl ) {
1076                if ( CodeGen::isDestructor( decl->name ) ) {
1077                        return dynamic_cast< const FunctionDecl * >( decl );
1078                }
1079                return nullptr;
1080        }
1081        const FunctionDecl * isDefaultConstructor( const Declaration * decl ) {
1082                if ( CodeGen::isConstructor( decl->name ) ) {
1083                        if ( const FunctionDecl * func = dynamic_cast< const FunctionDecl * >( decl ) ) {
1084                                if ( func->type->parameters.size() == 1 ) {
1085                                        return func;
1086                                }
1087                        }
1088                }
1089                return nullptr;
1090        }
1091        const FunctionDecl * isCopyConstructor( const Declaration * decl ) {
1092                return isCopyFunction( decl, "?{}" );
1093        }
1094
1095        #if defined( __x86_64 ) || defined( __i386 ) // assembler comment to prevent assembler warning message
1096                #define ASM_COMMENT "#"
1097        #else // defined( __ARM_ARCH )
1098                #define ASM_COMMENT "//"
1099        #endif
1100        static const char * const data_section =  ".data" ASM_COMMENT;
1101        static const char * const tlsd_section = ".tdata" ASM_COMMENT;
1102        void addDataSectionAttribute( ObjectDecl * objDecl ) {
1103                const bool is_tls = objDecl->get_storageClasses().is_threadlocal_any();
1104                const char * section = is_tls ? tlsd_section : data_section;
1105                objDecl->attributes.push_back(new Attribute("section", {
1106                        new ConstantExpr( Constant::from_string( section ) )
1107                }));
1108        }
1109
1110        void addDataSectionAttribute( ast::ObjectDecl * objDecl ) {
1111                const bool is_tls = objDecl->storage.is_threadlocal_any();
1112                const char * section = is_tls ? tlsd_section : data_section;
1113                objDecl->attributes.push_back(new ast::Attribute("section", {
1114                        ast::ConstantExpr::from_string(objDecl->location, section)
1115                }));
1116        }
1117
1118}
Note: See TracBrowser for help on using the repository browser.