source: src/InitTweak/InitTweak.cc @ 44f6341

ADTaaron-thesisarm-ehast-experimentalcleanup-dtorsctordeferred_resndemanglerenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerpthread-emulationqualifiedEnumresolv-newwith_gc
Last change on this file since 44f6341 was 4d4882a, checked in by Rob Schluntz <rschlunt@…>, 8 years ago

implicitly insert missing copy constructors when appropriate, update test output

  • Property mode set to 100644
File size: 16.1 KB
Line 
1#include <algorithm>
2#include "InitTweak.h"
3#include "SynTree/Visitor.h"
4#include "SynTree/Statement.h"
5#include "SynTree/Initializer.h"
6#include "SynTree/Expression.h"
7#include "SynTree/Attribute.h"
8#include "GenPoly/GenPoly.h"
9#include "ResolvExpr/typeops.h"
10
11namespace InitTweak {
12        namespace {
13                class HasDesignations : public Visitor {
14                public:
15                        bool hasDesignations = false;
16                        template<typename Init>
17                        void handleInit( Init * init ) {
18                                if ( ! init->get_designators().empty() ) hasDesignations = true;
19                                else Visitor::visit( init );
20                        }
21                        virtual void visit( SingleInit * singleInit ) { handleInit( singleInit); }
22                        virtual void visit( ListInit * listInit ) { handleInit( listInit); }
23                };
24
25                class InitFlattener : public Visitor {
26                        public:
27                        virtual void visit( SingleInit * singleInit );
28                        virtual void visit( ListInit * listInit );
29                        std::list< Expression * > argList;
30                };
31
32                void InitFlattener::visit( SingleInit * singleInit ) {
33                        argList.push_back( singleInit->get_value()->clone() );
34                }
35
36                void InitFlattener::visit( ListInit * listInit ) {
37                        // flatten nested list inits
38                        std::list<Initializer*>::iterator it = listInit->begin();
39                        for ( ; it != listInit->end(); ++it ) {
40                                (*it)->accept( *this );
41                        }
42                }
43        }
44
45        std::list< Expression * > makeInitList( Initializer * init ) {
46                InitFlattener flattener;
47                maybeAccept( init, flattener );
48                return flattener.argList;
49        }
50
51        bool isDesignated( Initializer * init ) {
52                HasDesignations finder;
53                maybeAccept( init, finder );
54                return finder.hasDesignations;
55        }
56
57        class InitExpander::ExpanderImpl {
58        public:
59                virtual std::list< Expression * > next( std::list< Expression * > & indices ) = 0;
60                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices ) = 0;
61        };
62
63        class InitImpl : public InitExpander::ExpanderImpl {
64        public:
65                InitImpl( Initializer * init ) : init( init ) {}
66
67                virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
68                        // this is wrong, but just a placeholder for now
69                        // if ( ! flattened ) flatten( indices );
70                        // return ! inits.empty() ? makeInitList( inits.front() ) : std::list< Expression * >();
71                        return makeInitList( init );
72                }
73
74                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices );
75        private:
76                Initializer * init;
77        };
78
79        class ExprImpl : public InitExpander::ExpanderImpl {
80        public:
81                ExprImpl( Expression * expr ) : arg( expr ) {}
82
83                ~ExprImpl() { delete arg; }
84
85                virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
86                        std::list< Expression * > ret;
87                        Expression * expr = maybeClone( arg );
88                        if ( expr ) {
89                                for ( std::list< Expression * >::reverse_iterator it = indices.rbegin(); it != indices.rend(); ++it ) {
90                                        // go through indices and layer on subscript exprs ?[?]
91                                        ++it;
92                                        UntypedExpr * subscriptExpr = new UntypedExpr( new NameExpr( "?[?]") );
93                                        subscriptExpr->get_args().push_back( expr );
94                                        subscriptExpr->get_args().push_back( (*it)->clone() );
95                                        expr = subscriptExpr;
96                                }
97                                ret.push_back( expr );
98                        }
99                        return ret;
100                }
101
102                virtual Statement * buildListInit( UntypedExpr * callExpr, std::list< Expression * > & indices );
103        private:
104                Expression * arg;
105        };
106
107        InitExpander::InitExpander( Initializer * init ) : expander( new InitImpl( init ) ) {}
108
109        InitExpander::InitExpander( Expression * expr ) : expander( new ExprImpl( expr ) ) {}
110
111        std::list< Expression * > InitExpander::operator*() {
112                return cur;
113        }
114
115        InitExpander & InitExpander::operator++() {
116                cur = expander->next( indices );
117                return *this;
118        }
119
120        // use array indices list to build switch statement
121        void InitExpander::addArrayIndex( Expression * index, Expression * dimension ) {
122                indices.push_back( index );
123                indices.push_back( dimension );
124        }
125
126        void InitExpander::clearArrayIndices() {
127                deleteAll( indices );
128                indices.clear();
129        }
130
131        namespace {
132                /// given index i, dimension d, initializer init, and callExpr f, generates
133                ///   if (i < d) f(..., init)
134                ///   ++i;
135                /// so that only elements within the range of the array are constructed
136                template< typename OutIterator >
137                void buildCallExpr( UntypedExpr * callExpr, Expression * index, Expression * dimension, Initializer * init, OutIterator out ) {
138                        UntypedExpr * cond = new UntypedExpr( new NameExpr( "?<?") );
139                        cond->get_args().push_back( index->clone() );
140                        cond->get_args().push_back( dimension->clone() );
141
142                        std::list< Expression * > args = makeInitList( init );
143                        callExpr->get_args().splice( callExpr->get_args().end(), args );
144
145                        *out++ = new IfStmt( noLabels, cond, new ExprStmt( noLabels, callExpr ), NULL );
146
147                        UntypedExpr * increment = new UntypedExpr( new NameExpr( "++?" ) );
148                        increment->get_args().push_back( new AddressExpr( index->clone() ) );
149                        *out++ = new ExprStmt( noLabels, increment );
150                }
151
152                template< typename OutIterator >
153                void build( UntypedExpr * callExpr, InitExpander::IndexList::iterator idx, InitExpander::IndexList::iterator idxEnd, Initializer * init, OutIterator out ) {
154                        if ( idx == idxEnd ) return;
155                        Expression * index = *idx++;
156                        assert( idx != idxEnd );
157                        Expression * dimension = *idx++;
158
159                        // xxx - may want to eventually issue a warning here if we can detect
160                        // that the number of elements exceeds to dimension of the array
161                        if ( idx == idxEnd ) {
162                                if ( ListInit * listInit = dynamic_cast< ListInit * >( init ) ) {
163                                        for ( Initializer * init : *listInit ) {
164                                                buildCallExpr( callExpr->clone(), index, dimension, init, out );
165                                        }
166                                } else {
167                                        buildCallExpr( callExpr->clone(), index, dimension, init, out );
168                                }
169                        } else {
170                                std::list< Statement * > branches;
171
172                                unsigned long cond = 0;
173                                ListInit * listInit = dynamic_cast< ListInit * >( init );
174                                if ( ! listInit ) {
175                                        // xxx - this shouldn't be an error, but need a way to
176                                        // terminate without creating output, so should catch this error
177                                        throw SemanticError( "unbalanced list initializers" );
178                                }
179
180                                static UniqueName targetLabel( "L__autogen__" );
181                                Label switchLabel( targetLabel.newName(), 0, std::list< Attribute * >{ new Attribute("unused") } );
182                                for ( Initializer * init : *listInit ) {
183                                        Expression * condition;
184                                        // check for designations
185                                        // if ( init-> ) {
186                                                condition = new ConstantExpr( Constant::from_ulong( cond ) );
187                                                ++cond;
188                                        // } else {
189                                        //      condition = // ... take designation
190                                        //      cond = // ... take designation+1
191                                        // }
192                                        std::list< Statement * > stmts;
193                                        build( callExpr, idx, idxEnd, init, back_inserter( stmts ) );
194                                        stmts.push_back( new BranchStmt( noLabels, switchLabel, BranchStmt::Break ) );
195                                        CaseStmt * caseStmt = new CaseStmt( noLabels, condition, stmts );
196                                        branches.push_back( caseStmt );
197                                }
198                                *out++ = new SwitchStmt( noLabels, index->clone(), branches );
199                                *out++ = new NullStmt( std::list<Label>{ switchLabel } );
200                        }
201                }
202        }
203
204        // if array came with an initializer list: initialize each element
205        // may have more initializers than elements in the array - need to check at each index that
206        // we haven't exceeded size.
207        // may have fewer initializers than elements in the array - need to default construct
208        // remaining elements.
209        // To accomplish this, generate switch statement, consuming all of expander's elements
210        Statement * InitImpl::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
211                if ( ! init ) return NULL;
212                CompoundStmt * block = new CompoundStmt( noLabels );
213                build( dst, indices.begin(), indices.end(), init, back_inserter( block->get_kids() ) );
214                if ( block->get_kids().empty() ) {
215                        delete block;
216                        return NULL;
217                } else {
218                        init = NULL; // init was consumed in creating the list init
219                        return block;
220                }
221        }
222
223        Statement * ExprImpl::buildListInit( UntypedExpr * dst, std::list< Expression * > & indices ) {
224                return NULL;
225        }
226
227        Statement * InitExpander::buildListInit( UntypedExpr * dst ) {
228                return expander->buildListInit( dst, indices );
229        }
230
231        bool tryConstruct( ObjectDecl * objDecl ) {
232                return ! LinkageSpec::isBuiltin( objDecl->get_linkage() ) &&
233                        (objDecl->get_init() == NULL ||
234                                ( objDecl->get_init() != NULL && objDecl->get_init()->get_maybeConstructed() )) &&
235                        ! isDesignated( objDecl->get_init() )
236                        && objDecl->get_storageClass() != DeclarationNode::Extern;
237        }
238
239        class CallFinder : public Visitor {
240        public:
241                typedef Visitor Parent;
242                CallFinder( const std::list< std::string > & names ) : names( names ) {}
243
244                virtual void visit( ApplicationExpr * appExpr ) {
245                        handleCallExpr( appExpr );
246                }
247
248                virtual void visit( UntypedExpr * untypedExpr ) {
249                        handleCallExpr( untypedExpr );
250                }
251
252                std::list< Expression * > * matches;
253        private:
254                const std::list< std::string > names;
255
256                template< typename CallExpr >
257                void handleCallExpr( CallExpr * expr ) {
258                        Parent::visit( expr );
259                        std::string fname = getFunctionName( expr );
260                        if ( std::find( names.begin(), names.end(), fname ) != names.end() ) {
261                                matches->push_back( expr );
262                        }
263                }
264        };
265
266        void collectCtorDtorCalls( Statement * stmt, std::list< Expression * > & matches ) {
267                static CallFinder finder( std::list< std::string >{ "?{}", "^?{}" } );
268                finder.matches = &matches;
269                maybeAccept( stmt, finder );
270        }
271
272        Expression * getCtorDtorCall( Statement * stmt ) {
273                std::list< Expression * > matches;
274                collectCtorDtorCalls( stmt, matches );
275                assert( matches.size() <= 1 );
276                return matches.size() == 1 ? matches.front() : NULL;
277        }
278
279        namespace {
280                VariableExpr * getCalledFunction( ApplicationExpr * appExpr ) {
281                        assert( appExpr );
282                        // xxx - it's possible this can be other things, e.g. MemberExpr, so this is insufficient
283                        return dynamic_cast< VariableExpr * >( appExpr->get_function() );
284                }
285        }
286
287        ApplicationExpr * isIntrinsicCallExpr( Expression * expr ) {
288                ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr );
289                if ( ! appExpr ) return NULL;
290                VariableExpr * function = getCalledFunction( appExpr );
291                assert( function );
292                // check for Intrinsic only - don't want to remove all overridable ctor/dtors because autogenerated ctor/dtor
293                // will call all member dtors, and some members may have a user defined dtor.
294                return function->get_var()->get_linkage() == LinkageSpec::Intrinsic ? appExpr : NULL;
295        }
296
297        namespace {
298                template <typename Predicate>
299                bool allofCtorDtor( Statement * stmt, const Predicate & pred ) {
300                        std::list< Expression * > callExprs;
301                        collectCtorDtorCalls( stmt, callExprs );
302                        // if ( callExprs.empty() ) return false; // xxx - do I still need this check?
303                        return std::all_of( callExprs.begin(), callExprs.end(), pred);
304                }
305        }
306
307        bool isIntrinsicSingleArgCallStmt( Statement * stmt ) {
308                return allofCtorDtor( stmt, []( Expression * callExpr ){
309                        if ( ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
310                                assert( ! appExpr->get_function()->get_results().empty() );
311                                FunctionType *funcType = GenPoly::getFunctionType( appExpr->get_function()->get_results().front() );
312                                assert( funcType );
313                                return funcType->get_parameters().size() == 1;
314                        }
315                        return false;
316                });
317        }
318
319        bool isIntrinsicCallStmt( Statement * stmt ) {
320                return allofCtorDtor( stmt, []( Expression * callExpr ) {
321                        return isIntrinsicCallExpr( callExpr );
322                });
323        }
324
325        namespace {
326                template<typename CallExpr>
327                Expression *& callArg( CallExpr * callExpr, unsigned int pos ) {
328                        if ( pos >= callExpr->get_args().size() ) assert( false && "asking for argument that doesn't exist. Return NULL/throw exception?" );
329                        for ( Expression *& arg : callExpr->get_args() ) {
330                                if ( pos == 0 ) return arg;
331                                pos--;
332                        }
333                        assert( false );
334                }
335        }
336
337        Expression *& getCallArg( Expression * callExpr, unsigned int pos ) {
338                if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( callExpr ) ) {
339                        return callArg( appExpr, pos );
340                } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( callExpr ) ) {
341                        return callArg( untypedExpr, pos );
342                } else {
343                        assert( false && "Unexpected expression type passed to getCallArg" );
344                }
345        }
346
347        namespace {
348                std::string funcName( Expression * func ) {
349                        if ( NameExpr * nameExpr = dynamic_cast< NameExpr * >( func ) ) {
350                                return nameExpr->get_name();
351                        } else if ( VariableExpr * varExpr = dynamic_cast< VariableExpr * >( func ) ) {
352                                return varExpr->get_var()->get_name();
353                        }       else if ( CastExpr * castExpr = dynamic_cast< CastExpr * >( func ) ) {
354                                return funcName( castExpr->get_arg() );
355                        } else {
356                                assert( false && "Unexpected expression type being called as a function in call expression" );
357                        }
358                }
359        }
360
361        std::string getFunctionName( Expression * expr ) {
362                if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr ) ) {
363                        return funcName( appExpr->get_function() );
364                } else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * > ( expr ) ) {
365                        return funcName( untypedExpr->get_function() );
366                } else {
367                        std::cerr << expr << std::endl;
368                        assert( false && "Unexpected expression type passed to getFunctionName" );
369                }
370        }
371
372        Type * getPointerBase( Type * type ) {
373                if ( PointerType * ptrType = dynamic_cast< PointerType * >( type ) ) {
374                        return ptrType->get_base();
375                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
376                        return arrayType->get_base();
377                } else {
378                        return NULL;
379                }
380        }
381
382        Type * isPointerType( Type * type ) {
383                if ( getPointerBase( type ) ) return type;
384                else return NULL;
385        }
386
387        class ConstExprChecker : public Visitor {
388        public:
389                ConstExprChecker() : isConstExpr( true ) {}
390
391                virtual void visit( ApplicationExpr *applicationExpr ) { isConstExpr = false; }
392                virtual void visit( UntypedExpr *untypedExpr ) { isConstExpr = false; }
393                virtual void visit( NameExpr *nameExpr ) { isConstExpr = false; }
394                virtual void visit( CastExpr *castExpr ) { isConstExpr = false; }
395                virtual void visit( LabelAddressExpr *labAddressExpr ) { isConstExpr = false; }
396                virtual void visit( UntypedMemberExpr *memberExpr ) { isConstExpr = false; }
397                virtual void visit( MemberExpr *memberExpr ) { isConstExpr = false; }
398                virtual void visit( VariableExpr *variableExpr ) { isConstExpr = false; }
399                virtual void visit( ConstantExpr *constantExpr ) { /* bottom out */ }
400                // these might be okay?
401                // virtual void visit( SizeofExpr *sizeofExpr );
402                // virtual void visit( AlignofExpr *alignofExpr );
403                // virtual void visit( UntypedOffsetofExpr *offsetofExpr );
404                // virtual void visit( OffsetofExpr *offsetofExpr );
405                // virtual void visit( OffsetPackExpr *offsetPackExpr );
406                // virtual void visit( AttrExpr *attrExpr );
407                // virtual void visit( CommaExpr *commaExpr );
408                // virtual void visit( LogicalExpr *logicalExpr );
409                // virtual void visit( ConditionalExpr *conditionalExpr );
410                virtual void visit( TupleExpr *tupleExpr ) { isConstExpr = false; }
411                virtual void visit( SolvedTupleExpr *tupleExpr ) { isConstExpr = false; }
412                virtual void visit( TypeExpr *typeExpr ) { isConstExpr = false; }
413                virtual void visit( AsmExpr *asmExpr ) { isConstExpr = false; }
414                virtual void visit( UntypedValofExpr *valofExpr ) { isConstExpr = false; }
415                virtual void visit( CompoundLiteralExpr *compLitExpr ) { isConstExpr = false; }
416
417                bool isConstExpr;
418        };
419
420        bool isConstExpr( Expression * expr ) {
421                if ( expr ) {
422                        ConstExprChecker checker;
423                        expr->accept( checker );
424                        return checker.isConstExpr;
425                }
426                return true;
427        }
428
429        bool isConstExpr( Initializer * init ) {
430                if ( init ) {
431                        ConstExprChecker checker;
432                        init->accept( checker );
433                        return checker.isConstExpr;
434                } // if
435                // for all intents and purposes, no initializer means const expr
436                return true;
437        }
438
439        bool isConstructor( const std::string & str ) { return str == "?{}"; }
440        bool isDestructor( const std::string & str ) { return str == "^?{}"; }
441        bool isCtorDtor( const std::string & str ) { return isConstructor( str ) || isDestructor( str ); }
442
443        FunctionDecl * isCopyConstructor( Declaration * decl ) {
444                FunctionDecl * function = dynamic_cast< FunctionDecl * >( decl );
445                if ( ! function ) return 0;
446                if ( ! isConstructor( function->get_name() ) ) return 0;
447                FunctionType * ftype = function->get_functionType();
448                if ( ftype->get_parameters().size() != 2 ) return 0;
449
450                Type * t1 = ftype->get_parameters().front()->get_type();
451                Type * t2 = ftype->get_parameters().back()->get_type();
452                PointerType * ptrType = dynamic_cast< PointerType * > ( t1 );
453                assert( ptrType );
454
455                if ( ResolvExpr::typesCompatible( ptrType->get_base(), t2, SymTab::Indexer() ) ) {
456                        return function;
457                } else {
458                        return 0;
459                }
460        }
461}
Note: See TracBrowser for help on using the repository browser.