source: src/Concurrency/Waituntil.cpp @ a5294af

ast-experimental
Last change on this file since a5294af was bccd70a, checked in by Andrew Beach <ajbeach@…>, 18 months ago

Removed internal code from TypeSubstitution? header. It caused a chain of include problems, which have been corrected.

  • Property mode set to 100644
File size: 57.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waituntil.hpp"
17
18#include <string>
19
20#include "AST/Copy.hpp"
21#include "AST/Expr.hpp"
22#include "AST/Pass.hpp"
23#include "AST/Print.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/Type.hpp"
26#include "Common/UniqueName.h"
27
28using namespace ast;
29using namespace std;
30
31/* So this is what this pass dones:
32{
33    when ( condA ) waituntil( A ){ doA(); }
34    or when ( condB ) waituntil( B ){ doB(); }
35    and when ( condC ) waituntil( C ) { doC(); }
36}
37                 ||
38                 ||
39                \||/
40                 \/
41
42Generates these two routines:
43static inline bool is_full_sat_1( int * clause_statuses ) {
44    return clause_statuses[0]
45        || clause_statuses[1]
46        && clause_statuses[2];
47}
48
49static inline bool is_done_sat_1( int * clause_statuses ) {
50    return has_run(clause_statuses[0])
51        || has_run(clause_statuses[1])
52        && has_run(clause_statuses[2]);
53}
54
55Replaces the waituntil statement above with the following code:
56{
57    // used with atomic_dec/inc to get binary semaphore behaviour
58    int park_counter = 0;
59
60    // status (one for each clause)
61    int clause_statuses[3] = { 0 };
62
63    bool whenA = condA;
64    bool whenB = condB;
65    bool whenC = condC;
66
67    if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
68    if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
69
70    // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
71
72    // three blocks
73    // for each block, create, setup, then register select_node
74    select_node clause1;
75    select_node clause2;
76    select_node clause3;
77
78    try {
79        if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
80        ... repeat ^ for B and C ...
81
82        // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
83
84        // loop & park until done
85        while( !is_full_sat_1( clause_statuses ) ) {
86           
87            // binary sem P();
88            if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
89                park();
90           
91            // execute any blocks available with status set to 0
92            for ( int i = 0; i < 3; i++ ) {
93                if (clause_statuses[i] == __SELECT_SAT) {
94                    switch (i) {
95                        case 0:
96                            try {
97                                if (on_selected( A, clause1 ))
98                                    doA();
99                            }
100                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
101                            break;
102                        case 1:
103                            ... same gen as A but for B and clause2 ...
104                            break;
105                        case 2:
106                            ... same gen as A but for C and clause3 ...
107                            break;
108                    }
109                }
110            }
111        }
112
113        // ensure that the blocks that triggered is_full_sat_1 are run
114        // by running every un-run block that is SAT from the start until
115        // the predicate is SAT when considering RUN status = true
116        for ( int i = 0; i < 3; i++ ) {
117            if (is_done_sat_1( clause_statuses )) break;
118            if (clause_statuses[i] == __SELECT_SAT)
119                ... Same if body here as in loop above ...
120        }
121    } finally {
122        // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
123        // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
124        if ( ! has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) && on_selected( A, clause1 ) )
125            doA();
126        ... repeat if above for B and C ...
127    }
128}
129
130*/
131
132namespace Concurrency {
133
134class GenerateWaitUntilCore final {
135    vector<FunctionDecl *> & satFns;
136        UniqueName namer_sat = "__is_full_sat_"s;
137    UniqueName namer_run = "__is_run_sat_"s;
138        UniqueName namer_park = "__park_counter_"s;
139        UniqueName namer_status = "__clause_statuses_"s;
140        UniqueName namer_node = "__clause_"s;
141    UniqueName namer_target = "__clause_target_"s;
142    UniqueName namer_when = "__when_cond_"s;
143
144    string idxName = "__CFA_clause_idx_";
145
146    struct ClauseData {
147        string nodeName;
148        string targetName;
149        string whenName;
150        int index;
151        string & statusName;
152        ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
153    };
154
155    const StructDecl * selectNodeDecl = nullptr;
156
157    // This first set of routines are all used to do the complicated job of
158    //    dealing with how to set predicate statuses with certain when_conds T/F
159    //    so that the when_cond == F effectively makes that clause "disappear"
160    void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
161    void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
162    bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
163    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
164    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
165    void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
166    void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
167    void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
168    CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
169    Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
170
171    // These routines are just code-gen helpers
172    void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
173    void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
174    ForStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
175    Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
176    CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
177    Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
178    Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
179    void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
180    Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
181    Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
182    Stmt * genAllOr( const WaitUntilStmt * stmt );
183
184  public:
185    void previsit( const StructDecl * decl );
186        Stmt * postvisit( const WaitUntilStmt * stmt );
187    GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
188};
189
190// Finds select_node decl
191void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
192    if ( !decl->body ) {
193                return;
194        } else if ( "select_node" == decl->name ) {
195                assert( !selectNodeDecl );
196                selectNodeDecl = decl;
197        }
198}
199
200void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
201    // all children when-ambiguous
202    if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
203        // true iff an ancestor/descendant has a different operation
204        currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
205    // ambiguousWhen is initially false so theres no need to set it here
206}
207
208// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
209// This tree painting is needed to generate the if statements that set the initial state
210//    of the clause statuses when some clauses are turned off via when_cond
211// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
212// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
213// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
214void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
215    bool aBelow = false; // updated by child nodes
216    bool oBelow = false; // updated by child nodes
217    switch (currNode->op) {
218        case WaitUntilStmt::ClauseNode::AND:
219            paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
220            paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
221
222            // update currNode's when flag based on conditions listed in fn signature comment above
223            updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
224
225            // set return flags to tell parents which decendant ops have been seen
226            andBelow = true;
227            orBelow = oBelow;
228            return;
229        case WaitUntilStmt::ClauseNode::OR:
230            paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
231            paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
232
233            // update currNode's when flag based on conditions listed in fn signature comment above
234            updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
235
236            // set return flags to tell parents which decendant ops have been seen
237            andBelow = aBelow;
238            orBelow = true;
239            return;
240        case WaitUntilStmt::ClauseNode::LEAF:
241            if ( currNode->leaf->when_cond )
242                currNode->ambiguousWhen = true;
243            return;
244        default:
245            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
246    }
247}
248
249// overloaded wrapper for paintWhenTree that sets initial values
250// returns true if entire tree is OR's (special case)
251bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
252    bool aBelow = false, oBelow = false; // unused by initial call
253    paintWhenTree( currNode, false, false, aBelow, oBelow );
254    return !aBelow;
255}
256
257// Helper: returns Expr that represents arrName[index]
258Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
259    return new UntypedExpr ( loc, 
260        new NameExpr( loc, "?[?]" ),
261        {
262            new NameExpr( loc, arrName ),
263            ConstantExpr::from_int( loc, index )
264        }
265    );
266}
267
268// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
269// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
270// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
271// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
272void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
273    switch (currNode->op) {
274        case WaitUntilStmt::ClauseNode::AND:
275            collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
276            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
277            return;
278        case WaitUntilStmt::ClauseNode::OR:
279            collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
280            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
281            return;
282        case WaitUntilStmt::ClauseNode::LEAF:
283            if ( parentAmbig ) {
284                ambigIdxs.push_back(make_pair(index, currNode));
285            }
286            if ( parentAnd && currNode->leaf->when_cond ) {
287                currNode->childOfAnd = true;
288                andIdxs.push_back(index);
289            }
290            index++;
291            return;
292        default:
293            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
294    }
295}
296
297// overloaded wrapper for collectWhens that sets initial values
298void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
299    int idx = 0;
300    collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
301}
302
303// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
304//    subtrees are "turned off"
305// sets whenState = false iff both children have whenState == false.
306// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
307// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
308void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
309    if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
310    updateWhenState( currNode->left );
311    updateWhenState( currNode->right );
312    if ( !currNode->left->whenState && !currNode->right->whenState )
313        currNode->whenState = false;
314    else 
315        currNode->whenState = true;
316}
317
318// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
319// assumes that this will only be called on subtrees that are entirely whenState == false
320void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
321    if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
322        || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
323        // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
324        genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
325        genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
326    } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
327        || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
328        // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
329        CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
330        CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
331
332        // only one side needs to evaluate to status so we recurse on both subtrees
333        //    but only keep the statements from the subtree with minimal statements
334        genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
335        genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
336       
337        // append minimal statements to retStmt
338        if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
339            retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
340        } else {
341            retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
342        }
343       
344        delete leftStmt;
345        delete rightStmt;
346    } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
347        const CodeLocation & loc = stmt->location;
348        if ( status && !currNode->childOfAnd ) {
349            retStmt->push_back(
350                new ExprStmt( loc, 
351                    UntypedExpr::createAssign( loc,
352                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
353                        new NameExpr( loc, "__SELECT_RUN" )
354                    )
355                )
356            );
357        } else if ( !status && currNode->childOfAnd ) {
358            retStmt->push_back(
359                new ExprStmt( loc, 
360                    UntypedExpr::createAssign( loc,
361                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
362                        new NameExpr( loc, "__SELECT_UNSAT" )
363                    )
364                )
365            );
366        }
367
368        // No need to generate statements for the following cases since childOfAnd are always set to true
369        //    and !childOfAnd are always false
370        // - status && currNode->childOfAnd
371        // - !status && !currNode->childOfAnd
372        idx++;
373    }
374}
375
376void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
377    switch (currNode->op) {
378        case WaitUntilStmt::ClauseNode::AND:
379            // check which subtrees have all whenState == false (disabled)
380            if (!currNode->left->whenState && !currNode->right->whenState) {
381                // this case can only occur when whole tree is disabled since otherwise
382                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
383                assert( !currNode->whenState ); // paranoidWWW
384                // whole tree disabled so pass true so that select is SAT vacuously
385                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
386            } else if ( !currNode->left->whenState ) {
387                // pass true since x && true === x
388                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
389                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
390            } else if ( !currNode->right->whenState ) {
391                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
392                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
393            } else { 
394                // if no children with whenState == false recurse normally via break
395                break;
396            }
397            return;
398        case WaitUntilStmt::ClauseNode::OR:
399            if (!currNode->left->whenState && !currNode->right->whenState) {
400                assert( !currNode->whenState ); // paranoid
401                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
402            } else if ( !currNode->left->whenState ) {
403                // pass false since x || false === x
404                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
405                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
406            } else if ( !currNode->right->whenState ) {
407                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
408                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
409            } else { 
410                break;
411            }
412            return;
413        case WaitUntilStmt::ClauseNode::LEAF:
414            idx++;
415            return;
416        default:
417            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
418    }
419    genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
420    genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
421}
422
423// generates a minimal set of assignments for status arr based on which whens are toggled on/off
424CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
425    updateWhenState( stmt->predicateTree );
426    CompoundStmt * retval = new CompoundStmt( stmt->location );
427    int idx = 0;
428    genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
429    return retval;
430}
431
432// generates nested if/elses for all possible assignments of ambiguous when_conds
433// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
434Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, 
435    vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
436    // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
437    //    Why is size_type parameterized on the type stored in the vector?????
438
439    const CodeLocation & loc = stmt->location;
440    int clauseIdx = ambigClauses.at(ambigIdx).first;
441    WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
442    Stmt * thenStmt;
443    Stmt * elseStmt;
444   
445    if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
446        currNode->whenState = true;
447        thenStmt = getStatusAssignment( stmt, clauseData );
448        currNode->whenState = false;
449        elseStmt = getStatusAssignment( stmt, clauseData );
450    } else {
451        // recurse both with when enabled and disabled to generate all possible cases
452        currNode->whenState = true;
453        thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
454        currNode->whenState = false;
455        elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
456    }
457
458    // insert first recursion result in if ( __when_cond_ ) { ... }
459    // insert second recursion result in else { ... }
460    return new CompoundStmt ( loc,
461        {
462            new IfStmt( loc,
463                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
464                thenStmt,
465                elseStmt
466            )
467        }
468    );
469}
470
471// typedef a fn ptr so that we can reuse genPredExpr
472// genLeafExpr is used to refer to one of the following two routines
473typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
474
475// return Expr that represents clause_statuses[index]
476// mutates index to be index + 1
477Expr * genSatExpr( const CodeLocation & loc, int & index ) {
478    return genArrAccessExpr( loc, index++, "clause_statuses" );
479}
480
481// return Expr that represents has_run(clause_statuses[index])
482Expr * genRunExpr( const CodeLocation & loc, int & index ) {
483    return new UntypedExpr ( loc, 
484        new NameExpr( loc, "__CFA_has_clause_run" ),
485        { genSatExpr( loc, index ) }
486    );
487}
488
489// Takes in the ClauseNode tree and recursively generates
490// the predicate expr used inside the predicate functions
491Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
492    switch (currNode->op) {
493        case WaitUntilStmt::ClauseNode::AND:
494            return new LogicalExpr( loc, 
495                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
496                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
497                LogicalFlag::AndExpr
498            );
499        case WaitUntilStmt::ClauseNode::OR:
500            return new LogicalExpr( loc,
501                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
502                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
503                LogicalFlag::OrExpr );
504        case WaitUntilStmt::ClauseNode::LEAF:
505            return genLeaf( loc, idx );
506        default:
507            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
508    }
509}
510
511
512// Builds the predicate functions used to check the status of the waituntil statement
513/* Ex:
514{
515    waituntil( A ){ doA(); }
516    or waituntil( B ){ doB(); }
517    and waituntil( C ) { doC(); }
518}
519generates =>
520static inline bool is_full_sat_1( int * clause_statuses ) {
521    return clause_statuses[0]
522        || clause_statuses[1]
523        && clause_statuses[2];
524}
525
526static inline bool is_done_sat_1( int * clause_statuses ) {
527    return has_run(clause_statuses[0])
528        || has_run(clause_statuses[1])
529        && has_run(clause_statuses[2]);
530}
531*/
532// Returns a predicate function decl
533// predName and genLeaf determine if this generates an is_done or an is_full predicate
534FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
535    int arrIdx = 0;
536    const CodeLocation & loc = stmt->location;
537    CompoundStmt * body = new CompoundStmt( loc );
538    body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
539
540    return new FunctionDecl( loc,
541        predName,
542        {},                     // forall
543        {
544            new ObjectDecl( loc,
545                "clause_statuses",
546                new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) )
547            )
548        },
549        { 
550            new ObjectDecl( loc,
551                "sat_ret",
552                new BasicType( BasicType::Kind::Bool )
553            )
554        },
555        body,               // body
556        { Storage::Static },    // storage
557        Linkage::Cforall,       // linkage
558        {},                     // attributes
559        { Function::Inline }
560    );
561}
562
563// Creates is_done and is_full predicates
564void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
565    if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
566        satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) ); 
567    satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
568}
569
570// Adds the following to body:
571// if ( when_cond ) { // this if is omitted if no when() condition
572//      setup_clause( clause1, &clause_statuses[0], &park_counter );
573//      register_select(A, clause1);
574// }
575void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {   
576    CompoundStmt * currBody = body;
577    const CodeLocation & loc = clause->location;
578
579    // If we have a when_cond make the initialization conditional
580    if ( clause->when_cond )
581        currBody = new CompoundStmt( loc );
582
583    // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
584    currBody->push_back( new ExprStmt( loc,
585        new UntypedExpr ( loc,
586            new NameExpr( loc, "setup_clause" ),
587            {
588                new NameExpr( loc, data->nodeName ),
589                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
590                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
591            }
592        )
593    ));
594
595    // Generates: register_select(A, clause1);
596    currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
597
598    // generates: if ( when_cond ) { ... currBody ... }
599    if ( clause->when_cond )
600        body->push_back( 
601            new IfStmt( loc,
602                new NameExpr( loc, data->whenName ),
603                currBody
604            )
605        );
606}
607
608// Used to generate a call to one of the select trait routines
609Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
610    const CodeLocation & loc = clause->location;
611    return new UntypedExpr ( loc,
612        new NameExpr( loc, fnName ),
613        {
614            new NameExpr( loc, data->targetName ),
615            new NameExpr( loc, data->nodeName )
616        }
617    );
618}
619
620// Generates:
621/* if ( on_selected( target_1, node_1 )) ... corresponding body of target_1 ...
622*/
623CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
624    const CodeLocation & cLoc = clause->location;
625    return new CompoundStmt( cLoc,
626        {
627            new IfStmt( cLoc,
628                genSelectTraitCall( clause, data, "on_selected" ),
629                new CompoundStmt( cLoc,
630                    {
631                        ast::deepCopy( clause->stmt )
632                    }
633                )
634            )
635        }
636    );
637}
638
639// this routine generates and returns the following
640/*for ( int i = 0; i < numClauses; i++ ) {
641    if ( predName(clause_statuses) ) break;
642    if (clause_statuses[i] == __SELECT_SAT) {
643        switch (i) {
644            case 0:
645                try {
646                    if (on_selected( target1, clause1 ))
647                        dotarget1stmt();
648                }
649                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
650                break;
651            ...
652            case N:
653                ...
654                break;
655        }
656    }
657}*/
658ForStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
659    CompoundStmt * ifBody = new CompoundStmt( stmt->location );
660    const CodeLocation & loc = stmt->location;
661
662    /* generates:
663    switch (i) {
664        case 0:
665            try {
666                if (on_selected( target1, clause1 ))
667                    dotarget1stmt();
668            }
669            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
670            break;
671            ...
672        case N:
673            ...
674            break;
675    }*/
676    std::vector<ptr<CaseClause>> switchCases;
677    int idx = 0;
678    for ( const auto & clause: stmt->clauses ) {
679        const CodeLocation & cLoc = clause->location;
680        switchCases.push_back(
681            new CaseClause( cLoc,
682                ConstantExpr::from_int( cLoc, idx ),
683                {
684                    new CompoundStmt( cLoc,
685                        {
686                            new ast::TryStmt( cLoc,
687                                genStmtBlock( clause, clauseData.at(idx) ),
688                                {},
689                                new ast::FinallyClause( cLoc, 
690                                    new CompoundStmt( cLoc,
691                                        {
692                                            new ExprStmt( loc,
693                                                new UntypedExpr ( loc,
694                                                    new NameExpr( loc, "?=?" ),
695                                                    {
696                                                        new UntypedExpr ( loc, 
697                                                            new NameExpr( loc, "?[?]" ),
698                                                            {
699                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
700                                                                new NameExpr( loc, idxName )
701                                                            }
702                                                        ),
703                                                        new NameExpr( loc, "__SELECT_RUN" )
704                                                    }
705                                                )
706                                            ),
707                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
708                                        }
709                                    )
710                                )
711                            ),
712                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc ) )
713                        }
714                    )
715                }
716            )
717        );
718        idx++;
719    }
720
721    ifBody->push_back(
722        new SwitchStmt( loc,
723            new NameExpr( loc, idxName ),
724            std::move( switchCases )
725        )
726    );
727
728    // gens:
729    // if (clause_statuses[i] == __SELECT_SAT) {
730    //      ... ifBody  ...
731    // }
732    IfStmt * ifSwitch = new IfStmt( loc,
733        new UntypedExpr ( loc,
734            new NameExpr( loc, "?==?" ),
735            {
736                new UntypedExpr ( loc, 
737                    new NameExpr( loc, "?[?]" ),
738                    {
739                        new NameExpr( loc, clauseData.at(0)->statusName ),
740                        new NameExpr( loc, idxName )
741                    }
742                ),
743                new NameExpr( loc, "__SELECT_SAT" )
744            }
745        ),      // condition
746        ifBody  // body
747    );
748
749    return new ForStmt( loc,
750        {
751            new DeclStmt( loc,
752                new ObjectDecl( loc,
753                    idxName,
754                    new BasicType( BasicType::Kind::SignedInt ),
755                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
756                )
757            )
758        },  // inits
759        new UntypedExpr ( loc,
760            new NameExpr( loc, "?<?" ),
761            {
762                new NameExpr( loc, idxName ),
763                ConstantExpr::from_int( loc, stmt->clauses.size() )
764            }
765        ),  // cond
766        new UntypedExpr ( loc,
767            new NameExpr( loc, "?++" ),
768            { new NameExpr( loc, idxName ) }
769        ),  // inc
770        new CompoundStmt( loc,
771            {
772                new IfStmt( loc,
773                    new UntypedExpr ( loc,
774                        new NameExpr( loc, predName ),
775                        { new NameExpr( loc, clauseData.at(0)->statusName ) }
776                    ),
777                    new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc ) )
778                ),
779                ifSwitch
780            }
781        )   // body
782    );
783}
784
785// Generates: !is_full_sat_n() / !is_run_sat_n()
786Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
787    const CodeLocation & loc = stmt->location;
788    return new UntypedExpr ( loc,
789        new NameExpr( loc, "!?" ),
790        {
791            new UntypedExpr ( loc,
792                new NameExpr( loc, satName ),
793                { new NameExpr( loc, arrName ) }
794            )
795        }
796    );
797}
798
799// Generates the code needed for waituntils with an else ( ... )
800// Checks clauses once after registering for completion and runs them if completes
801// If not enough have run to satisfy predicate after one pass then the else is run
802Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
803    return new CompoundStmt( stmt->else_stmt->location,
804        {
805            genStatusCheckFor( stmt, clauseData, runName ),
806            new IfStmt( stmt->else_stmt->location,
807                genNotSatExpr( stmt, runName, arrName ),
808                ast::deepCopy( stmt->else_stmt )
809            )
810        }
811    );
812}
813
814Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
815    CompoundStmt * whileBody = new CompoundStmt( stmt->location );
816    const CodeLocation & loc = stmt->location;
817
818    // generates: __CFA_maybe_park( &park_counter );
819    whileBody->push_back(
820        new ExprStmt( loc,
821            new UntypedExpr ( loc,
822                new NameExpr( loc, "__CFA_maybe_park" ),
823                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
824            )
825        )
826    );
827
828    whileBody->push_back( genStatusCheckFor( stmt, clauseData, satName ) );
829
830    return new CompoundStmt( loc,
831        {
832            new WhileDoStmt( loc,
833                genNotSatExpr( stmt, satName, arrName ),
834                whileBody,  // body
835                {}          // no inits
836            ),
837            genStatusCheckFor( stmt, clauseData, runName )
838        }
839    );
840}
841
842// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
843// typeof(target) & __clause_target_0 = target;
844// bool __when_cond_0 = when_cond; // only generated if when_cond defined
845// select_node clause1;
846void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
847    ClauseData * currClause;
848    for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
849        currClause = new ClauseData( i, statusName );
850        currClause->nodeName = namer_node.newName();
851        currClause->targetName = namer_target.newName();
852        currClause->whenName = namer_when.newName();
853        clauseData.push_back(currClause);
854        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
855
856        // typeof(target) & __clause_target_0 = target;
857        body->push_back(
858            new DeclStmt( cLoc,
859                new ObjectDecl( cLoc,
860                    currClause->targetName,
861                    new ReferenceType( new TypeofType( ast::deepCopy( stmt->clauses.at(i)->target ) ) ),
862                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
863                )
864            )
865        );
866
867        // bool __when_cond_0 = when_cond; // only generated if when_cond defined
868        if ( stmt->clauses.at(i)->when_cond )
869            body->push_back(
870                new DeclStmt( cLoc,
871                    new ObjectDecl( cLoc,
872                        currClause->whenName,
873                        new BasicType( BasicType::Kind::Bool ),
874                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
875                    )
876                )
877            );
878       
879        // select_node clause1;
880        body->push_back(
881            new DeclStmt( cLoc,
882                new ObjectDecl( cLoc,
883                    currClause->nodeName,
884                    new StructInstType( selectNodeDecl )
885                )
886            )
887        );
888    }
889
890    if ( stmt->else_stmt && stmt->else_cond ) {
891        body->push_back(
892            new DeclStmt( stmt->else_cond->location,
893                new ObjectDecl( stmt->else_cond->location,
894                    elseWhenName,
895                    new BasicType( BasicType::Kind::Bool ),
896                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
897                )
898            )
899        );
900    }
901}
902
903/*
904if ( clause_status == &clause1 ) ... clause 1 body ...
905...
906elif ( clause_status == &clausen ) ... clause n body ...
907*/
908Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
909    const CodeLocation & loc = stmt->location;
910
911    IfStmt * outerIf = nullptr;
912        IfStmt * lastIf = nullptr;
913
914        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
915        for ( long unsigned int i = 0; i < data.size(); i++ ) {
916        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
917
918                IfStmt * currIf = new IfStmt( cLoc,
919                        new UntypedExpr( cLoc, 
920                new NameExpr( cLoc, "?==?" ), 
921                {
922                    new NameExpr( cLoc, statusName ),
923                    new CastExpr( cLoc, 
924                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
925                        new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
926                    )
927                }
928            ),
929            genStmtBlock( stmt->clauses.at(i), data.at(i) )
930                );
931               
932                if ( i == 0 ) {
933                        outerIf = currIf;
934                } else {
935                        // add ifstmt to else of previous stmt
936                        lastIf->else_ = currIf;
937                }
938
939                lastIf = currIf;
940        }
941
942    // C_TODO: will remove this commented code later. Currently it isn't needed but may switch to a modified version of this later if it has better performance
943    // std::vector<ptr<CaseClause>> switchCases;
944
945    // int idx = 0;
946    // for ( const auto & clause: stmt->clauses ) {
947    //     const CodeLocation & cLoc = clause->location;
948    //     switchCases.push_back(
949    //         new CaseClause( cLoc,
950    //             new CastExpr( cLoc,
951    //                 new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->targetName ) ),
952    //                 new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
953    //             ),
954    //             {
955    //                 new CompoundStmt( cLoc,
956    //                     {
957    //                         ast::deepCopy( clause->stmt ),
958    //                         new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc ) )
959    //                     }
960    //                 )
961    //             }
962    //         )
963    //     );
964    //     idx++;
965    // }
966
967    return new CompoundStmt( loc,
968        {
969            new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
970            outerIf
971            // new SwitchStmt( loc,
972            //     new NameExpr( loc, statusName ),
973            //     std::move( switchCases )
974            // )
975        }
976    );
977}
978
979Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
980    if ( idx == data.size() ) {   // base case, gen last else
981        const CodeLocation & cLoc = stmt->else_stmt->location;
982        if ( !stmt->else_stmt ) // normal non-else gen
983            return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
984
985        Expr * raceFnCall = new UntypedExpr( stmt->location,
986            new NameExpr( stmt->location, "__select_node_else_race" ),
987            { new NameExpr( stmt->location, data.at(0)->nodeName ) }
988        );
989
990        if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
991            return new IfStmt( cLoc,
992                new LogicalExpr( cLoc,
993                    new CastExpr( cLoc,
994                        new NameExpr( cLoc, elseWhenName ),
995                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
996                    ),
997                    new CastExpr( cLoc,
998                        raceFnCall,
999                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1000                    ),
1001                    LogicalFlag::AndExpr
1002                ),
1003                ast::deepCopy( stmt->else_stmt ),
1004                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1005            );
1006        }
1007
1008        // return else conditional on race
1009        return new IfStmt( stmt->else_stmt->location,
1010            raceFnCall,
1011            ast::deepCopy( stmt->else_stmt ),
1012            buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1013        );
1014    }
1015    const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1016
1017    Expr * ifCond;
1018
1019    // If we have a when_cond make the register call conditional on it
1020    if ( stmt->clauses.at(idx)->when_cond ) {
1021        ifCond = new LogicalExpr( cLoc,
1022            new CastExpr( cLoc,
1023                new NameExpr( cLoc, data.at(idx)->whenName ), 
1024                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1025            ),
1026            new CastExpr( cLoc,
1027                genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" ),
1028                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1029            ),
1030            LogicalFlag::AndExpr
1031        );
1032    } else ifCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
1033
1034    return new CompoundStmt( cLoc,
1035        {   // gens: setup_clause( clause1, &status, 0p );
1036            new ExprStmt( cLoc,
1037                new UntypedExpr ( cLoc,
1038                    new NameExpr( cLoc, "setup_clause" ),
1039                    {
1040                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1041                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1042                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::SignedInt ) ) )
1043                    }
1044                )
1045            ),
1046            // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1047            new IfStmt( cLoc,
1048                ifCond,
1049                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1050                // ast::deepCopy( stmt->clauses.at(idx)->stmt ),
1051                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1052            )
1053        }
1054    );
1055}
1056
1057// This gens the special case of an all OR waituntil:
1058/*
1059int status = 0;
1060
1061typeof(target) & __clause_target_0 = target;
1062bool __when_cond_0 = when_cond; // only generated if when_cond defined
1063select_node clause1;
1064... generate above for rest of clauses ...
1065
1066try {
1067    setup_clause( clause1, &status, 0p );
1068    if ( __when_cond_0 && register_select( 1 ) ) {
1069        ... clause 1 body ...
1070    } else {
1071        ... recursively gen for each of n clauses ...
1072        setup_clause( clausen, &status, 0p );
1073        if ( __when_cond_n-1 && register_select( n ) ) {
1074            ... clause n body ...
1075        } else {
1076            if ( else_when ) ... else clause body ...
1077            else {
1078                park();
1079
1080                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1081                if ( clause_status == &clause1) ... clause 1 body ...
1082                ...
1083                elif ( clause_status == &clausen ) ... clause n body ...
1084            }
1085        }
1086    }
1087}
1088finally {
1089    if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1090    ...
1091    if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1092}
1093*/
1094Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1095    const CodeLocation & loc = stmt->location;
1096    string statusName = namer_status.newName();
1097    string elseWhenName = namer_when.newName();
1098    int numClauses = stmt->clauses.size();
1099    CompoundStmt * body = new CompoundStmt( stmt->location );
1100
1101    // Generates: unsigned long int status = 0;
1102    body->push_back( new DeclStmt( loc,
1103        new ObjectDecl( loc,
1104            statusName,
1105            new BasicType( BasicType::Kind::LongUnsignedInt ),
1106            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1107        )
1108    ));
1109
1110    vector<ClauseData *> clauseData;
1111    genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1112
1113    vector<int> whenIndices; // track which clauses have whens
1114
1115    CompoundStmt * unregisters = new CompoundStmt( loc );
1116    Expr * ifCond;
1117    for ( int i = 0; i < numClauses; i++ ) {
1118        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1119        // Gens: node.status != 0p
1120        UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc, 
1121            new NameExpr( cLoc, "?!=?" ), 
1122            {
1123                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) ) ),
1124                new UntypedExpr( cLoc, 
1125                    new NameExpr( cLoc, "__get_clause_status" ), 
1126                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) } 
1127                ) 
1128            }
1129        );
1130
1131        // If we have a when_cond make the unregister call conditional on it
1132        if ( stmt->clauses.at(i)->when_cond ) {
1133            whenIndices.push_back(i);
1134            ifCond = new LogicalExpr( cLoc,
1135                new CastExpr( cLoc,
1136                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1137                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1138                ),
1139                new CastExpr( cLoc,
1140                    statusPtrCheck,
1141                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1142                ),
1143                LogicalFlag::AndExpr
1144            );
1145        } else ifCond = statusPtrCheck;
1146       
1147        unregisters->push_back(
1148            new IfStmt( cLoc,
1149                ifCond,
1150                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) ) 
1151            )
1152        );
1153    }
1154
1155    if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1156        body->push_back(
1157                new ast::TryStmt( loc,
1158                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1159                {},
1160                new ast::FinallyClause( loc, unregisters )
1161            )
1162        );
1163    } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1164        Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1165        Expr * lastExpr = outerIfCond;
1166
1167        for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1168            outerIfCond = new LogicalExpr( loc,
1169                new CastExpr( loc,
1170                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ), 
1171                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1172                ),
1173                new CastExpr( loc,
1174                    lastExpr,
1175                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1176                ),
1177                LogicalFlag::OrExpr
1178            );
1179            lastExpr = outerIfCond;
1180        }
1181
1182        body->push_back(
1183                new ast::TryStmt( loc,
1184                new CompoundStmt( loc, 
1185                    {
1186                        new IfStmt( loc,
1187                            outerIfCond,
1188                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1189                        )
1190                    }
1191                ),
1192                {},
1193                new ast::FinallyClause( loc, unregisters )
1194            )
1195        );
1196    }
1197
1198    for ( ClauseData * datum : clauseData )
1199        delete datum;
1200
1201    return body;
1202}
1203
1204Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1205    if ( !selectNodeDecl )
1206        SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1207
1208    // Prep clause tree to figure out how to set initial statuses
1209    // setTreeSizes( stmt->predicateTree );
1210    if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1211        return genAllOr( stmt );
1212
1213    CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1214    CompoundStmt * body = new CompoundStmt( stmt->location );
1215    string statusArrName = namer_status.newName();
1216    string pCountName = namer_park.newName();
1217    string satName = namer_sat.newName();
1218    string runName = namer_run.newName();
1219    string elseWhenName = namer_when.newName();
1220    int numClauses = stmt->clauses.size();
1221    addPredicates( stmt, satName, runName );
1222
1223    const CodeLocation & loc = stmt->location;
1224
1225    // Generates: int park_counter = 0;
1226    body->push_back( new DeclStmt( loc,
1227        new ObjectDecl( loc,
1228            pCountName,
1229            new BasicType( BasicType::Kind::SignedInt ),
1230            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1231        )
1232    ));
1233
1234    // Generates: int clause_statuses[3] = { 0 };
1235    body->push_back( new DeclStmt( loc,
1236        new ObjectDecl( loc,
1237            statusArrName,
1238            new ArrayType( new BasicType( BasicType::Kind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1239            new ListInit( loc,
1240                {
1241                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1242                }
1243            )
1244        )
1245    ));
1246
1247    vector<ClauseData *> clauseData;
1248    genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1249
1250    vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1251    vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1252
1253    collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1254
1255    // This is only needed for clauses that have AND as a parent and a when_cond defined
1256    // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1257    for ( int idx : andWhenClauses ) {
1258        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1259        body->push_back( 
1260            new IfStmt( cLoc,
1261                new UntypedExpr ( cLoc,
1262                    new NameExpr( cLoc, "!?" ),
1263                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1264                ),  // IfStmt cond
1265                new ExprStmt( cLoc,
1266                    new UntypedExpr ( cLoc,
1267                        new NameExpr( cLoc, "?=?" ),
1268                        {
1269                            new UntypedExpr ( cLoc, 
1270                                new NameExpr( cLoc, "?[?]" ),
1271                                {
1272                                    new NameExpr( cLoc, statusArrName ),
1273                                    ConstantExpr::from_int( cLoc, idx )
1274                                }
1275                            ),
1276                            new NameExpr( cLoc, "__SELECT_RUN" )
1277                        }
1278                    )
1279                )  // IfStmt then
1280            )
1281        );
1282    }
1283
1284    // Only need to generate conditional initial state setting for ambiguous when clauses
1285    if ( !ambiguousClauses.empty() ) {
1286        body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1287    }
1288
1289    // generates the following for each clause:
1290    // setup_clause( clause1, &clause_statuses[0], &park_counter );
1291    // register_select(A, clause1);
1292    for ( int i = 0; i < numClauses; i++ ) {
1293        setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1294    }
1295
1296    // generate satisfy logic based on if there is an else clause and if it is conditional
1297    if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1298        tryBody->push_back(
1299            new IfStmt( stmt->else_cond->location,
1300                new NameExpr( stmt->else_cond->location, elseWhenName ),
1301                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1302                genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData )
1303            )
1304        );
1305    } else if ( !stmt->else_stmt ) { // normal gen
1306        tryBody->push_back( genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData ) );
1307    } else { // generate just else
1308        tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1309    }
1310
1311    CompoundStmt * unregisters = new CompoundStmt( loc );
1312    // generates for each clause:
1313    // if ( !has_run( clause_statuses[i] ) )
1314    // OR if when_cond defined
1315    // if ( when_cond_i && !has_run( clause_statuses[i] ) )
1316    // body of if is:
1317    // { if (unregister_select(A, clause1) && on_selected(A, clause1)) clause1->stmt; } // this conditionally runs the block unregister_select returns true (needed by some primitives)
1318    Expr * ifCond;
1319    UntypedExpr * statusExpr; // !clause_statuses[i]
1320    for ( int i = 0; i < numClauses; i++ ) {
1321        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1322
1323        statusExpr = new UntypedExpr ( cLoc,
1324            new NameExpr( cLoc, "!?" ),
1325            {
1326                new UntypedExpr ( cLoc, 
1327                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1328                    {
1329                        genArrAccessExpr( cLoc, i, statusArrName )
1330                    }
1331                )
1332            }
1333        );
1334
1335        if ( stmt->clauses.at(i)->when_cond ) {
1336            // generates: if( when_cond_i && !has_run(clause_statuses[i]) )
1337            ifCond = new LogicalExpr( cLoc,
1338                new CastExpr( cLoc,
1339                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1340                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1341                ),
1342                new CastExpr( cLoc,
1343                    statusExpr,
1344                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1345                ),
1346                LogicalFlag::AndExpr
1347            );
1348        } else // generates: if( !clause_statuses[i] )
1349            ifCond = statusExpr;
1350       
1351        unregisters->push_back( 
1352            new IfStmt( cLoc,
1353                ifCond,
1354                new CompoundStmt( cLoc,
1355                    {
1356                        new IfStmt( cLoc,
1357                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1358                            // ast::deepCopy( stmt->clauses.at(i)->stmt )
1359                            genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1360                        )
1361                    }
1362                )
1363               
1364            )
1365        );
1366    }
1367
1368    body->push_back( 
1369        new ast::TryStmt(
1370            loc,
1371            tryBody,
1372            {},
1373            new ast::FinallyClause( loc, unregisters )
1374        )
1375    );
1376
1377    for ( ClauseData * datum : clauseData )
1378        delete datum;
1379
1380    return body;
1381}
1382
1383// To add the predicates at global scope we need to do it in a second pass
1384// Predicates are added after "struct select_node { ... };"
1385class AddPredicateDecls final : public WithDeclsToAdd<> {
1386    vector<FunctionDecl *> & satFns;
1387    const StructDecl * selectNodeDecl = nullptr;
1388
1389  public:
1390    void previsit( const StructDecl * decl ) {
1391        if ( !decl->body ) {
1392            return;
1393        } else if ( "select_node" == decl->name ) {
1394            assert( !selectNodeDecl );
1395            selectNodeDecl = decl;
1396            for ( FunctionDecl * fn : satFns )
1397                declsToAddAfter.push_back(fn);           
1398        }
1399    }
1400    AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1401};
1402
1403void generateWaitUntil( TranslationUnit & translationUnit ) {
1404    vector<FunctionDecl *> satFns;
1405        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1406    Pass<AddPredicateDecls>::run( translationUnit, satFns );
1407}
1408
1409} // namespace Concurrency
1410
1411// Local Variables: //
1412// tab-width: 4 //
1413// mode: c++ //
1414// compile-command: "make install" //
1415// End: //
Note: See TracBrowser for help on using the repository browser.