Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/Tuples/TupleExpansion.cc

    rc6b4432 r9939dc3  
    2323#include "AST/Node.hpp"
    2424#include "AST/Type.hpp"
     25#include "Common/PassVisitor.h"   // for PassVisitor, WithDeclsToAdd, WithGu...
    2526#include "Common/ScopedMap.h"     // for ScopedMap
    2627#include "Common/utility.h"       // for CodeLocation
    2728#include "InitTweak/InitTweak.h"  // for getFunction
     29#include "SynTree/LinkageSpec.h"  // for Spec, C, Intrinsic
     30#include "SynTree/Constant.h"     // for Constant
     31#include "SynTree/Declaration.h"  // for StructDecl, DeclarationWithType
     32#include "SynTree/Expression.h"   // for UntypedMemberExpr, Expression, Uniq...
     33#include "SynTree/Label.h"        // for operator==, Label
     34#include "SynTree/Mutator.h"      // for Mutator
     35#include "SynTree/Type.h"         // for Type, Type::Qualifiers, TupleType
     36#include "SynTree/Visitor.h"      // for Visitor
    2837#include "Tuples.h"
    2938
     39class CompoundStmt;
     40class TypeSubstitution;
     41
    3042namespace Tuples {
    31 
     43        namespace {
     44                struct MemberTupleExpander final : public WithShortCircuiting, public WithVisitorRef<MemberTupleExpander> {
     45                        void premutate( UntypedMemberExpr * ) { visit_children = false; }
     46                        Expression * postmutate( UntypedMemberExpr * memberExpr );
     47                };
     48
     49                struct UniqueExprExpander final : public WithDeclsToAdd {
     50                        Expression * postmutate( UniqueExpr * unqExpr );
     51
     52                        std::map< int, Expression * > decls; // not vector, because order added may not be increasing order
     53
     54                        ~UniqueExprExpander() {
     55                                for ( std::pair<const int, Expression *> & p : decls ) {
     56                                        delete p.second;
     57                                }
     58                        }
     59                };
     60
     61                struct TupleAssignExpander {
     62                        Expression * postmutate( TupleAssignExpr * tupleExpr );
     63                };
     64
     65                struct TupleTypeReplacer : public WithDeclsToAdd, public WithGuards, public WithConstTypeSubstitution {
     66                        Type * postmutate( TupleType * tupleType );
     67
     68                        void premutate( CompoundStmt * ) {
     69                                GuardScope( typeMap );
     70                        }
     71                  private:
     72                        ScopedMap< int, StructDecl * > typeMap;
     73                };
     74
     75                struct TupleIndexExpander {
     76                        Expression * postmutate( TupleIndexExpr * tupleExpr );
     77                };
     78
     79                struct TupleExprExpander final {
     80                        Expression * postmutate( TupleExpr * tupleExpr );
     81                };
     82        }
     83
     84        void expandMemberTuples( std::list< Declaration * > & translationUnit ) {
     85                PassVisitor<MemberTupleExpander> expander;
     86                mutateAll( translationUnit, expander );
     87        }
     88
     89        void expandUniqueExpr( std::list< Declaration * > & translationUnit ) {
     90                PassVisitor<UniqueExprExpander> unqExpander;
     91                mutateAll( translationUnit, unqExpander );
     92        }
     93
     94        void expandTuples( std::list< Declaration * > & translationUnit ) {
     95                PassVisitor<TupleAssignExpander> assnExpander;
     96                mutateAll( translationUnit, assnExpander );
     97
     98                PassVisitor<TupleTypeReplacer> replacer;
     99                mutateAll( translationUnit, replacer );
     100
     101                PassVisitor<TupleIndexExpander> idxExpander;
     102                mutateAll( translationUnit, idxExpander );
     103
     104                PassVisitor<TupleExprExpander> exprExpander;
     105                mutateAll( translationUnit, exprExpander );
     106        }
     107
     108        namespace {
     109                /// given a expression representing the member and an expression representing the aggregate,
     110                /// reconstructs a flattened UntypedMemberExpr with the right precedence
     111                Expression * reconstructMemberExpr( Expression * member, Expression * aggr, CodeLocation & loc ) {
     112                        if ( UntypedMemberExpr * memberExpr = dynamic_cast< UntypedMemberExpr * >( member ) ) {
     113                                // construct a new UntypedMemberExpr with the correct structure , and recursively
     114                                // expand that member expression.
     115                                PassVisitor<MemberTupleExpander> expander;
     116                                UntypedMemberExpr * inner = new UntypedMemberExpr( memberExpr->aggregate, aggr->clone() );
     117                                UntypedMemberExpr * newMemberExpr = new UntypedMemberExpr( memberExpr->member, inner );
     118                                inner->location = newMemberExpr->location = loc;
     119                                memberExpr->member = nullptr;
     120                                memberExpr->aggregate = nullptr;
     121                                delete memberExpr;
     122                                return newMemberExpr->acceptMutator( expander );
     123                        } else {
     124                                // not a member expression, so there is nothing to do but attach and return
     125                                UntypedMemberExpr * newMemberExpr = new UntypedMemberExpr( member, aggr->clone() );
     126                                newMemberExpr->location = loc;
     127                                return newMemberExpr;
     128                        }
     129                }
     130        }
     131
     132        Expression * MemberTupleExpander::postmutate( UntypedMemberExpr * memberExpr ) {
     133                if ( UntypedTupleExpr * tupleExpr = dynamic_cast< UntypedTupleExpr * > ( memberExpr->member ) ) {
     134                        Expression * aggr = memberExpr->aggregate->clone()->acceptMutator( *visitor );
     135                        // aggregate expressions which might be impure must be wrapped in unique expressions
     136                        if ( Tuples::maybeImpureIgnoreUnique( memberExpr->aggregate ) ) aggr = new UniqueExpr( aggr );
     137                        for ( Expression *& expr : tupleExpr->exprs ) {
     138                                expr = reconstructMemberExpr( expr, aggr, memberExpr->location );
     139                                expr->location = memberExpr->location;
     140                        }
     141                        delete aggr;
     142                        tupleExpr->location = memberExpr->location;
     143                        return tupleExpr;
     144                } else {
     145                        // there may be a tuple expr buried in the aggregate
     146                        // xxx - this is a memory leak
     147                        UntypedMemberExpr * newMemberExpr = new UntypedMemberExpr( memberExpr->member->clone(), memberExpr->aggregate->acceptMutator( *visitor ) );
     148                        newMemberExpr->location = memberExpr->location;
     149                        return newMemberExpr;
     150                }
     151        }
     152
     153        Expression * UniqueExprExpander::postmutate( UniqueExpr * unqExpr ) {
     154                const int id = unqExpr->get_id();
     155
     156                // on first time visiting a unique expr with a particular ID, generate the expression that replaces all UniqueExprs with that ID,
     157                // and lookup on subsequent hits. This ensures that all unique exprs with the same ID reference the same variable.
     158                if ( ! decls.count( id ) ) {
     159                        Expression * assignUnq;
     160                        Expression * var = unqExpr->get_var();
     161                        if ( unqExpr->get_object() ) {
     162                                // an object was generated to represent this unique expression -- it should be added to the list of declarations now
     163                                declsToAddBefore.push_back( unqExpr->get_object() );
     164                                unqExpr->set_object( nullptr );
     165                                // steal the expr from the unqExpr
     166                                assignUnq = UntypedExpr::createAssign( unqExpr->get_var()->clone(), unqExpr->get_expr() );
     167                                unqExpr->set_expr( nullptr );
     168                        } else {
     169                                // steal the already generated assignment to var from the unqExpr - this has been generated by FixInit
     170                                Expression * expr = unqExpr->get_expr();
     171                                CommaExpr * commaExpr = strict_dynamic_cast< CommaExpr * >( expr );
     172                                assignUnq = commaExpr->get_arg1();
     173                                commaExpr->set_arg1( nullptr );
     174                        }
     175                        ObjectDecl * finished = new ObjectDecl( toString( "_unq", id, "_finished_" ), Type::StorageClasses(), LinkageSpec::Cforall, nullptr, new BasicType( Type::Qualifiers(), BasicType::Bool ),
     176                                                                                                        new SingleInit( new ConstantExpr( Constant::from_int( 0 ) ) ) );
     177                        declsToAddBefore.push_back( finished );
     178                        // (finished ? _unq_expr_N : (_unq_expr_N = <unqExpr->get_expr()>, finished = 1, _unq_expr_N))
     179                        // This pattern ensures that each unique expression is evaluated once, regardless of evaluation order of the generated C code.
     180                        Expression * assignFinished = UntypedExpr::createAssign( new VariableExpr(finished), new ConstantExpr( Constant::from_int( 1 ) ) );
     181                        ConditionalExpr * condExpr = new ConditionalExpr( new VariableExpr( finished ), var->clone(),
     182                                new CommaExpr( new CommaExpr( assignUnq, assignFinished ), var->clone() ) );
     183                        condExpr->set_result( var->get_result()->clone() );
     184                        condExpr->set_env( maybeClone( unqExpr->get_env() ) );
     185                        decls[id] = condExpr;
     186                }
     187                delete unqExpr;
     188                return decls[id]->clone();
     189        }
     190
     191        Expression * TupleAssignExpander::postmutate( TupleAssignExpr * assnExpr ) {
     192                StmtExpr * ret = assnExpr->get_stmtExpr();
     193                assnExpr->set_stmtExpr( nullptr );
     194                // move env to StmtExpr
     195                ret->set_env( assnExpr->get_env() );
     196                assnExpr->set_env( nullptr );
     197                delete assnExpr;
     198                return ret;
     199        }
     200
     201        Type * TupleTypeReplacer::postmutate( TupleType * tupleType ) {
     202                unsigned tupleSize = tupleType->size();
     203                if ( ! typeMap.count( tupleSize ) ) {
     204                        // generate struct type to replace tuple type based on the number of components in the tuple
     205                        StructDecl * decl = new StructDecl( toString( "_tuple", tupleSize, "_" ) );
     206                        decl->location = tupleType->location;
     207                        decl->set_body( true );
     208                        for ( size_t i = 0; i < tupleSize; ++i ) {
     209                                TypeDecl * tyParam = new TypeDecl( toString( "tuple_param_", tupleSize, "_", i ), Type::StorageClasses(), nullptr, TypeDecl::Dtype, true );
     210                                decl->get_members().push_back( new ObjectDecl( toString("field_", i ), Type::StorageClasses(), LinkageSpec::C, nullptr, new TypeInstType( Type::Qualifiers(), tyParam->get_name(), tyParam ), nullptr ) );
     211                                decl->get_parameters().push_back( tyParam );
     212                        }
     213                        if ( tupleSize == 0 ) {
     214                                // empty structs are not standard C. Add a dummy field to empty tuples to silence warnings when a compound literal Tuple0 is created.
     215                                decl->get_members().push_back( new ObjectDecl( "dummy", Type::StorageClasses(), LinkageSpec::C, nullptr, new BasicType( Type::Qualifiers(), BasicType::SignedInt ), nullptr ) );
     216                        }
     217                        typeMap[tupleSize] = decl;
     218                        declsToAddBefore.push_back( decl );
     219                }
     220                Type::Qualifiers qualifiers = tupleType->get_qualifiers();
     221
     222                StructDecl * decl = typeMap[tupleSize];
     223                StructInstType * newType = new StructInstType( qualifiers, decl );
     224                for ( auto p : group_iterate( tupleType->get_types(), decl->get_parameters() ) ) {
     225                        Type * t = std::get<0>(p);
     226                        newType->get_parameters().push_back( new TypeExpr( t->clone() ) );
     227                }
     228                delete tupleType;
     229                return newType;
     230        }
     231
     232        Expression * TupleIndexExpander::postmutate( TupleIndexExpr * tupleExpr ) {
     233                Expression * tuple = tupleExpr->tuple;
     234                assert( tuple );
     235                tupleExpr->tuple = nullptr;
     236                unsigned int idx = tupleExpr->index;
     237                TypeSubstitution * env = tupleExpr->env;
     238                tupleExpr->env = nullptr;
     239                delete tupleExpr;
     240
     241                if ( TupleExpr * tupleExpr = dynamic_cast< TupleExpr * > ( tuple ) ) {
     242                        if ( ! maybeImpureIgnoreUnique( tupleExpr ) ) {
     243                                // optimization: definitely pure tuple expr => can reduce to the only relevant component.
     244                                assert( tupleExpr->exprs.size() > idx );
     245                                Expression *& expr = *std::next(tupleExpr->exprs.begin(), idx);
     246                                Expression * ret = expr;
     247                                ret->env = env;
     248                                expr = nullptr; // remove from list so it can safely be deleted
     249                                delete tupleExpr;
     250                                return ret;
     251                        }
     252                }
     253
     254                StructInstType * type = strict_dynamic_cast< StructInstType * >( tuple->result );
     255                StructDecl * structDecl = type->baseStruct;
     256                assert( structDecl->members.size() > idx );
     257                Declaration * member = *std::next(structDecl->members.begin(), idx);
     258                MemberExpr * memExpr = new MemberExpr( strict_dynamic_cast< DeclarationWithType * >( member ), tuple );
     259                memExpr->env = env;
     260                return memExpr;
     261        }
     262
     263        Expression * replaceTupleExpr( Type * result, const std::list< Expression * > & exprs, TypeSubstitution * env ) {
     264                if ( result->isVoid() ) {
     265                        // void result - don't need to produce a value for cascading - just output a chain of comma exprs
     266                        assert( ! exprs.empty() );
     267                        std::list< Expression * >::const_iterator iter = exprs.begin();
     268                        Expression * expr = new CastExpr( *iter++ );
     269                        for ( ; iter != exprs.end(); ++iter ) {
     270                                expr = new CommaExpr( expr, new CastExpr( *iter ) );
     271                        }
     272                        expr->set_env( env );
     273                        return expr;
     274                } else {
     275                        // typed tuple expression - produce a compound literal which performs each of the expressions
     276                        // as a distinct part of its initializer - the produced compound literal may be used as part of
     277                        // another expression
     278                        std::list< Initializer * > inits;
     279                        for ( Expression * expr : exprs ) {
     280                                inits.push_back( new SingleInit( expr ) );
     281                        }
     282                        Expression * expr = new CompoundLiteralExpr( result, new ListInit( inits ) );
     283                        expr->set_env( env );
     284                        return expr;
     285                }
     286        }
     287
     288        Expression * TupleExprExpander::postmutate( TupleExpr * tupleExpr ) {
     289                Type * result = tupleExpr->get_result();
     290                std::list< Expression * > exprs = tupleExpr->get_exprs();
     291                assert( result );
     292                TypeSubstitution * env = tupleExpr->get_env();
     293
     294                // remove data from shell and delete it
     295                tupleExpr->set_result( nullptr );
     296                tupleExpr->get_exprs().clear();
     297                tupleExpr->set_env( nullptr );
     298                delete tupleExpr;
     299
     300                return replaceTupleExpr( result, exprs, env );
     301        }
     302
     303        Type * makeTupleType( const std::list< Expression * > & exprs ) {
     304                // produce the TupleType which aggregates the types of the exprs
     305                std::list< Type * > types;
     306                Type::Qualifiers qualifiers( Type::Const | Type::Volatile | Type::Restrict | Type::Atomic | Type::Mutex );
     307                for ( Expression * expr : exprs ) {
     308                        assert( expr->get_result() );
     309                        if ( expr->get_result()->isVoid() ) {
     310                                // if the type of any expr is void, the type of the entire tuple is void
     311                                return new VoidType( Type::Qualifiers() );
     312                        }
     313                        Type * type = expr->get_result()->clone();
     314                        types.push_back( type );
     315                        // the qualifiers on the tuple type are the qualifiers that exist on all component types
     316                        qualifiers &= type->get_qualifiers();
     317                } // for
     318                if ( exprs.empty() ) qualifiers = Type::Qualifiers();
     319                return new TupleType( qualifiers, types );
     320        }
    32321        const ast::Type * makeTupleType( const std::vector<ast::ptr<ast::Expr>> & exprs ) {
    33322                // produce the TupleType which aggregates the types of the exprs
     
    52341        }
    53342
     343        TypeInstType * isTtype( Type * type ) {
     344                if ( TypeInstType * inst = dynamic_cast< TypeInstType * >( type ) ) {
     345                        if ( inst->get_baseType() && inst->get_baseType()->get_kind() == TypeDecl::Ttype ) {
     346                                return inst;
     347                        }
     348                }
     349                return nullptr;
     350        }
     351
     352        const TypeInstType * isTtype( const Type * type ) {
     353                if ( const TypeInstType * inst = dynamic_cast< const TypeInstType * >( type ) ) {
     354                        if ( inst->baseType && inst->baseType->kind == TypeDecl::Ttype ) {
     355                                return inst;
     356                        }
     357                }
     358                return nullptr;
     359        }
     360
    54361        const ast::TypeInstType * isTtype( const ast::Type * type ) {
    55362                if ( const ast::TypeInstType * inst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
Note: See TracChangeset for help on using the changeset viewer.