source: src/Concurrency/Waituntil.cpp @ a3c7bac

Last change on this file since a3c7bac was b93bf85, checked in by caparsons <caparson@…>, 12 months ago

fixed spurious channel close waituntil error case. Was caused by a race condition causing an exception to be thrown while another was in flight

  • Property mode set to 100644
File size: 57.8 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waituntil.hpp"
17
18#include <string>
19
20#include "AST/Copy.hpp"
21#include "AST/Expr.hpp"
22#include "AST/Pass.hpp"
23#include "AST/Print.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/Type.hpp"
26#include "Common/UniqueName.h"
27
28using namespace ast;
29using namespace std;
30
31/* So this is what this pass dones:
32{
33    when ( condA ) waituntil( A ){ doA(); }
34    or when ( condB ) waituntil( B ){ doB(); }
35    and when ( condC ) waituntil( C ) { doC(); }
36}
37                 ||
38                 ||
39                \||/
40                 \/
41
42Generates these two routines:
43static inline bool is_full_sat_1( int * clause_statuses ) {
44    return clause_statuses[0]
45        || clause_statuses[1]
46        && clause_statuses[2];
47}
48
49static inline bool is_done_sat_1( int * clause_statuses ) {
50    return has_run(clause_statuses[0])
51        || has_run(clause_statuses[1])
52        && has_run(clause_statuses[2]);
53}
54
55Replaces the waituntil statement above with the following code:
56{
57    // used with atomic_dec/inc to get binary semaphore behaviour
58    int park_counter = 0;
59
60    // status (one for each clause)
61    int clause_statuses[3] = { 0 };
62
63    bool whenA = condA;
64    bool whenB = condB;
65    bool whenC = condC;
66
67    if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
68    if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
69
70    // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
71
72    // three blocks
73    // for each block, create, setup, then register select_node
74    select_node clause1;
75    select_node clause2;
76    select_node clause3;
77
78    try {
79        if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
80        ... repeat ^ for B and C ...
81
82        // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
83
84        // loop & park until done
85        while( !is_full_sat_1( clause_statuses ) ) {
86           
87            // binary sem P();
88            if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
89                park();
90           
91            // execute any blocks available with status set to 0
92            for ( int i = 0; i < 3; i++ ) {
93                if (clause_statuses[i] == __SELECT_SAT) {
94                    switch (i) {
95                        case 0:
96                            try {
97                                    on_selected( A, clause1 );
98                                    doA();
99                            }
100                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
101                            break;
102                        case 1:
103                            ... same gen as A but for B and clause2 ...
104                            break;
105                        case 2:
106                            ... same gen as A but for C and clause3 ...
107                            break;
108                    }
109                }
110            }
111        }
112
113        // ensure that the blocks that triggered is_full_sat_1 are run
114        // by running every un-run block that is SAT from the start until
115        // the predicate is SAT when considering RUN status = true
116        for ( int i = 0; i < 3; i++ ) {
117            if (is_done_sat_1( clause_statuses )) break;
118            if (clause_statuses[i] == __SELECT_SAT)
119                ... Same if body here as in loop above ...
120        }
121    } finally {
122        // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
123        // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
124        if ( !has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) )
125            on_selected( A, clause1 )
126            doA();
127        ... repeat if above for B and C ...
128    }
129}
130
131*/
132
133namespace Concurrency {
134
135class GenerateWaitUntilCore final {
136    vector<FunctionDecl *> & satFns;
137        UniqueName namer_sat = "__is_full_sat_"s;
138    UniqueName namer_run = "__is_run_sat_"s;
139        UniqueName namer_park = "__park_counter_"s;
140        UniqueName namer_status = "__clause_statuses_"s;
141        UniqueName namer_node = "__clause_"s;
142    UniqueName namer_target = "__clause_target_"s;
143    UniqueName namer_when = "__when_cond_"s;
144    UniqueName namer_label = "__waituntil_label_"s;
145
146    string idxName = "__CFA_clause_idx_";
147
148    struct ClauseData {
149        string nodeName;
150        string targetName;
151        string whenName;
152        int index;
153        string & statusName;
154        ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
155    };
156
157    const StructDecl * selectNodeDecl = nullptr;
158
159    // This first set of routines are all used to do the complicated job of
160    //    dealing with how to set predicate statuses with certain when_conds T/F
161    //    so that the when_cond == F effectively makes that clause "disappear"
162    void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
163    void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
164    bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
165    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
166    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
167    void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
168    void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
169    void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
170    CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
171    Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
172
173    // These routines are just code-gen helpers
174    void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
175    void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
176    CompoundStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
177    Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
178    CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
179    Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
180    Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
181    void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
182    Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
183    Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
184    Stmt * genAllOr( const WaitUntilStmt * stmt );
185
186  public:
187    void previsit( const StructDecl * decl );
188        Stmt * postvisit( const WaitUntilStmt * stmt );
189    GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
190};
191
192// Finds select_node decl
193void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
194    if ( !decl->body ) {
195                return;
196        } else if ( "select_node" == decl->name ) {
197                assert( !selectNodeDecl );
198                selectNodeDecl = decl;
199        }
200}
201
202void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
203    // all children when-ambiguous
204    if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
205        // true iff an ancestor/descendant has a different operation
206        currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
207    // ambiguousWhen is initially false so theres no need to set it here
208}
209
210// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
211// This tree painting is needed to generate the if statements that set the initial state
212//    of the clause statuses when some clauses are turned off via when_cond
213// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
214// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
215// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
216void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
217    bool aBelow = false; // updated by child nodes
218    bool oBelow = false; // updated by child nodes
219    switch (currNode->op) {
220        case WaitUntilStmt::ClauseNode::AND:
221            paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
222            paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
223
224            // update currNode's when flag based on conditions listed in fn signature comment above
225            updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
226
227            // set return flags to tell parents which decendant ops have been seen
228            andBelow = true;
229            orBelow = oBelow;
230            return;
231        case WaitUntilStmt::ClauseNode::OR:
232            paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
233            paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
234
235            // update currNode's when flag based on conditions listed in fn signature comment above
236            updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
237
238            // set return flags to tell parents which decendant ops have been seen
239            andBelow = aBelow;
240            orBelow = true;
241            return;
242        case WaitUntilStmt::ClauseNode::LEAF:
243            if ( currNode->leaf->when_cond )
244                currNode->ambiguousWhen = true;
245            return;
246        default:
247            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
248    }
249}
250
251// overloaded wrapper for paintWhenTree that sets initial values
252// returns true if entire tree is OR's (special case)
253bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
254    bool aBelow = false, oBelow = false; // unused by initial call
255    paintWhenTree( currNode, false, false, aBelow, oBelow );
256    return !aBelow;
257}
258
259// Helper: returns Expr that represents arrName[index]
260Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
261    return new UntypedExpr ( loc, 
262        new NameExpr( loc, "?[?]" ),
263        {
264            new NameExpr( loc, arrName ),
265            ConstantExpr::from_int( loc, index )
266        }
267    );
268}
269
270// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
271// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
272// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
273// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
274void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
275    switch (currNode->op) {
276        case WaitUntilStmt::ClauseNode::AND:
277            collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
278            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
279            return;
280        case WaitUntilStmt::ClauseNode::OR:
281            collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
282            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
283            return;
284        case WaitUntilStmt::ClauseNode::LEAF:
285            if ( parentAmbig ) {
286                ambigIdxs.push_back(make_pair(index, currNode));
287            }
288            if ( parentAnd && currNode->leaf->when_cond ) {
289                currNode->childOfAnd = true;
290                andIdxs.push_back(index);
291            }
292            index++;
293            return;
294        default:
295            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
296    }
297}
298
299// overloaded wrapper for collectWhens that sets initial values
300void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
301    int idx = 0;
302    collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
303}
304
305// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
306//    subtrees are "turned off"
307// sets whenState = false iff both children have whenState == false.
308// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
309// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
310void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
311    if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
312    updateWhenState( currNode->left );
313    updateWhenState( currNode->right );
314    if ( !currNode->left->whenState && !currNode->right->whenState )
315        currNode->whenState = false;
316    else 
317        currNode->whenState = true;
318}
319
320// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
321// assumes that this will only be called on subtrees that are entirely whenState == false
322void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
323    if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
324        || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
325        // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
326        genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
327        genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
328    } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
329        || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
330        // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
331        CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
332        CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
333
334        // only one side needs to evaluate to status so we recurse on both subtrees
335        //    but only keep the statements from the subtree with minimal statements
336        genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
337        genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
338       
339        // append minimal statements to retStmt
340        if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
341            retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
342        } else {
343            retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
344        }
345       
346        delete leftStmt;
347        delete rightStmt;
348    } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
349        const CodeLocation & loc = stmt->location;
350        if ( status && !currNode->childOfAnd ) {
351            retStmt->push_back(
352                new ExprStmt( loc, 
353                    UntypedExpr::createAssign( loc,
354                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
355                        new NameExpr( loc, "__SELECT_RUN" )
356                    )
357                )
358            );
359        } else if ( !status && currNode->childOfAnd ) {
360            retStmt->push_back(
361                new ExprStmt( loc, 
362                    UntypedExpr::createAssign( loc,
363                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
364                        new NameExpr( loc, "__SELECT_UNSAT" )
365                    )
366                )
367            );
368        }
369
370        // No need to generate statements for the following cases since childOfAnd are always set to true
371        //    and !childOfAnd are always false
372        // - status && currNode->childOfAnd
373        // - !status && !currNode->childOfAnd
374        idx++;
375    }
376}
377
378void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
379    switch (currNode->op) {
380        case WaitUntilStmt::ClauseNode::AND:
381            // check which subtrees have all whenState == false (disabled)
382            if (!currNode->left->whenState && !currNode->right->whenState) {
383                // this case can only occur when whole tree is disabled since otherwise
384                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
385                assert( !currNode->whenState ); // paranoidWWW
386                // whole tree disabled so pass true so that select is SAT vacuously
387                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
388            } else if ( !currNode->left->whenState ) {
389                // pass true since x && true === x
390                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
391                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
392            } else if ( !currNode->right->whenState ) {
393                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
394                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
395            } else { 
396                // if no children with whenState == false recurse normally via break
397                break;
398            }
399            return;
400        case WaitUntilStmt::ClauseNode::OR:
401            if (!currNode->left->whenState && !currNode->right->whenState) {
402                assert( !currNode->whenState ); // paranoid
403                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
404            } else if ( !currNode->left->whenState ) {
405                // pass false since x || false === x
406                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
407                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
408            } else if ( !currNode->right->whenState ) {
409                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
410                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
411            } else { 
412                break;
413            }
414            return;
415        case WaitUntilStmt::ClauseNode::LEAF:
416            idx++;
417            return;
418        default:
419            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
420    }
421    genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
422    genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
423}
424
425// generates a minimal set of assignments for status arr based on which whens are toggled on/off
426CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
427    updateWhenState( stmt->predicateTree );
428    CompoundStmt * retval = new CompoundStmt( stmt->location );
429    int idx = 0;
430    genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
431    return retval;
432}
433
434// generates nested if/elses for all possible assignments of ambiguous when_conds
435// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
436Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, 
437    vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
438    // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
439    //    Why is size_type parameterized on the type stored in the vector?????
440
441    const CodeLocation & loc = stmt->location;
442    int clauseIdx = ambigClauses.at(ambigIdx).first;
443    WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
444    Stmt * thenStmt;
445    Stmt * elseStmt;
446   
447    if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
448        currNode->whenState = true;
449        thenStmt = getStatusAssignment( stmt, clauseData );
450        currNode->whenState = false;
451        elseStmt = getStatusAssignment( stmt, clauseData );
452    } else {
453        // recurse both with when enabled and disabled to generate all possible cases
454        currNode->whenState = true;
455        thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
456        currNode->whenState = false;
457        elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
458    }
459
460    // insert first recursion result in if ( __when_cond_ ) { ... }
461    // insert second recursion result in else { ... }
462    return new CompoundStmt ( loc,
463        {
464            new IfStmt( loc,
465                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
466                thenStmt,
467                elseStmt
468            )
469        }
470    );
471}
472
473// typedef a fn ptr so that we can reuse genPredExpr
474// genLeafExpr is used to refer to one of the following two routines
475typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
476
477// return Expr that represents clause_statuses[index]
478// mutates index to be index + 1
479Expr * genSatExpr( const CodeLocation & loc, int & index ) {
480    return genArrAccessExpr( loc, index++, "clause_statuses" );
481}
482
483// return Expr that represents has_run(clause_statuses[index])
484Expr * genRunExpr( const CodeLocation & loc, int & index ) {
485    return new UntypedExpr ( loc, 
486        new NameExpr( loc, "__CFA_has_clause_run" ),
487        { genSatExpr( loc, index ) }
488    );
489}
490
491// Takes in the ClauseNode tree and recursively generates
492// the predicate expr used inside the predicate functions
493Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
494    switch (currNode->op) {
495        case WaitUntilStmt::ClauseNode::AND:
496            return new LogicalExpr( loc, 
497                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
498                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
499                LogicalFlag::AndExpr
500            );
501        case WaitUntilStmt::ClauseNode::OR:
502            return new LogicalExpr( loc,
503                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
504                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
505                LogicalFlag::OrExpr );
506        case WaitUntilStmt::ClauseNode::LEAF:
507            return genLeaf( loc, idx );
508        default:
509            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
510    }
511}
512
513
514// Builds the predicate functions used to check the status of the waituntil statement
515/* Ex:
516{
517    waituntil( A ){ doA(); }
518    or waituntil( B ){ doB(); }
519    and waituntil( C ) { doC(); }
520}
521generates =>
522static inline bool is_full_sat_1( int * clause_statuses ) {
523    return clause_statuses[0]
524        || clause_statuses[1]
525        && clause_statuses[2];
526}
527
528static inline bool is_done_sat_1( int * clause_statuses ) {
529    return has_run(clause_statuses[0])
530        || has_run(clause_statuses[1])
531        && has_run(clause_statuses[2]);
532}
533*/
534// Returns a predicate function decl
535// predName and genLeaf determine if this generates an is_done or an is_full predicate
536FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
537    int arrIdx = 0;
538    const CodeLocation & loc = stmt->location;
539    CompoundStmt * body = new CompoundStmt( loc );
540    body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
541
542    return new FunctionDecl( loc,
543        predName,
544        {},                     // forall
545        {
546            new ObjectDecl( loc,
547                "clause_statuses",
548                new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) )
549            )
550        },
551        { 
552            new ObjectDecl( loc,
553                "sat_ret",
554                new BasicType( BasicType::Kind::Bool )
555            )
556        },
557        body,               // body
558        { Storage::Static },    // storage
559        Linkage::Cforall,       // linkage
560        {},                     // attributes
561        { Function::Inline }
562    );
563}
564
565// Creates is_done and is_full predicates
566void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
567    if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
568        satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) ); 
569    satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
570}
571
572// Adds the following to body:
573// if ( when_cond ) { // this if is omitted if no when() condition
574//      setup_clause( clause1, &clause_statuses[0], &park_counter );
575//      register_select(A, clause1);
576// }
577void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {   
578    CompoundStmt * currBody = body;
579    const CodeLocation & loc = clause->location;
580
581    // If we have a when_cond make the initialization conditional
582    if ( clause->when_cond )
583        currBody = new CompoundStmt( loc );
584
585    // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
586    currBody->push_back( new ExprStmt( loc,
587        new UntypedExpr ( loc,
588            new NameExpr( loc, "setup_clause" ),
589            {
590                new NameExpr( loc, data->nodeName ),
591                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
592                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
593            }
594        )
595    ));
596
597    // Generates: register_select(A, clause1);
598    currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
599
600    // generates: if ( when_cond ) { ... currBody ... }
601    if ( clause->when_cond )
602        body->push_back( 
603            new IfStmt( loc,
604                new NameExpr( loc, data->whenName ),
605                currBody
606            )
607        );
608}
609
610// Used to generate a call to one of the select trait routines
611Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
612    const CodeLocation & loc = clause->location;
613    return new UntypedExpr ( loc,
614        new NameExpr( loc, fnName ),
615        {
616            new NameExpr( loc, data->targetName ),
617            new NameExpr( loc, data->nodeName )
618        }
619    );
620}
621
622// Generates:
623/* on_selected( target_1, node_1 ); ... corresponding body of target_1 ...
624*/
625CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
626    const CodeLocation & cLoc = clause->location;
627    return new CompoundStmt( cLoc,
628        {
629            new ExprStmt( cLoc,
630                genSelectTraitCall( clause, data, "on_selected" )
631            ),
632            ast::deepCopy( clause->stmt )
633        }
634    );
635}
636
637// this routine generates and returns the following
638/*for ( int i = 0; i < numClauses; i++ ) {
639    if ( predName(clause_statuses) ) break;
640    if (clause_statuses[i] == __SELECT_SAT) {
641        switch (i) {
642            case 0:
643                try {
644                    on_selected( target1, clause1 );
645                    dotarget1stmt();
646                }
647                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
648                break;
649            ...
650            case N:
651                ...
652                break;
653        }
654    }
655}*/
656CompoundStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
657    CompoundStmt * ifBody = new CompoundStmt( stmt->location );
658    const CodeLocation & loc = stmt->location;
659
660    string switchLabel = namer_label.newName();
661
662    /* generates:
663    switch (i) {
664        case 0:
665            try {
666                on_selected( target1, clause1 );
667                dotarget1stmt();
668            }
669            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
670            break;
671            ...
672        case N:
673            ...
674            break;
675    }*/
676    std::vector<ptr<CaseClause>> switchCases;
677    int idx = 0;
678    for ( const auto & clause: stmt->clauses ) {
679        const CodeLocation & cLoc = clause->location;
680        switchCases.push_back(
681            new CaseClause( cLoc,
682                ConstantExpr::from_int( cLoc, idx ),
683                {
684                    new CompoundStmt( cLoc,
685                        {
686                            new ast::TryStmt( cLoc,
687                                genStmtBlock( clause, clauseData.at(idx) ),
688                                {},
689                                new ast::FinallyClause( cLoc, 
690                                    new CompoundStmt( cLoc,
691                                        {
692                                            new ExprStmt( loc,
693                                                new UntypedExpr ( loc,
694                                                    new NameExpr( loc, "?=?" ),
695                                                    {
696                                                        new UntypedExpr ( loc, 
697                                                            new NameExpr( loc, "?[?]" ),
698                                                            {
699                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
700                                                                new NameExpr( loc, idxName )
701                                                            }
702                                                        ),
703                                                        new NameExpr( loc, "__SELECT_RUN" )
704                                                    }
705                                                )
706                                            ),
707                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
708                                        }
709                                    )
710                                )
711                            ),
712                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc, switchLabel ) )
713                        }
714                    )
715                }
716            )
717        );
718        idx++;
719    }
720
721    ifBody->push_back(
722        new SwitchStmt( loc,
723            new NameExpr( loc, idxName ),
724            std::move( switchCases ),
725            { Label( loc, switchLabel ) }
726        )
727    );
728
729    // gens:
730    // if (clause_statuses[i] == __SELECT_SAT) {
731    //      ... ifBody  ...
732    // }
733    IfStmt * ifSwitch = new IfStmt( loc,
734        new UntypedExpr ( loc,
735            new NameExpr( loc, "?==?" ),
736            {
737                new UntypedExpr ( loc, 
738                    new NameExpr( loc, "?[?]" ),
739                    {
740                        new NameExpr( loc, clauseData.at(0)->statusName ),
741                        new NameExpr( loc, idxName )
742                    }
743                ),
744                new NameExpr( loc, "__SELECT_SAT" )
745            }
746        ),      // condition
747        ifBody  // body
748    );
749
750    string forLabel = namer_label.newName();
751
752    // we hoist init here so that this pass can happen after hoistdecls pass
753    return new CompoundStmt( loc,
754        {
755            new DeclStmt( loc,
756                new ObjectDecl( loc,
757                    idxName,
758                    new BasicType( BasicType::Kind::SignedInt ),
759                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
760                )
761            ),
762            new ForStmt( loc,
763                {},  // inits
764                new UntypedExpr ( loc,
765                    new NameExpr( loc, "?<?" ),
766                    {
767                        new NameExpr( loc, idxName ),
768                        ConstantExpr::from_int( loc, stmt->clauses.size() )
769                    }
770                ),  // cond
771                new UntypedExpr ( loc,
772                    new NameExpr( loc, "?++" ),
773                    { new NameExpr( loc, idxName ) }
774                ),  // inc
775                new CompoundStmt( loc,
776                    {
777                        new IfStmt( loc,
778                            new UntypedExpr ( loc,
779                                new NameExpr( loc, predName ),
780                                { new NameExpr( loc, clauseData.at(0)->statusName ) }
781                            ),
782                            new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc, forLabel ) )
783                        ),
784                        ifSwitch
785                    }
786                ),   // body
787                { Label( loc, forLabel ) }
788            )
789        }
790    );
791}
792
793// Generates: !is_full_sat_n() / !is_run_sat_n()
794Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
795    const CodeLocation & loc = stmt->location;
796    return new UntypedExpr ( loc,
797        new NameExpr( loc, "!?" ),
798        {
799            new UntypedExpr ( loc,
800                new NameExpr( loc, satName ),
801                { new NameExpr( loc, arrName ) }
802            )
803        }
804    );
805}
806
807// Generates the code needed for waituntils with an else ( ... )
808// Checks clauses once after registering for completion and runs them if completes
809// If not enough have run to satisfy predicate after one pass then the else is run
810Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
811    return new CompoundStmt( stmt->else_stmt->location,
812        {
813            genStatusCheckFor( stmt, clauseData, runName ),
814            new IfStmt( stmt->else_stmt->location,
815                genNotSatExpr( stmt, runName, arrName ),
816                ast::deepCopy( stmt->else_stmt )
817            )
818        }
819    );
820}
821
822Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
823    CompoundStmt * whileBody = new CompoundStmt( stmt->location );
824    const CodeLocation & loc = stmt->location;
825
826    // generates: __CFA_maybe_park( &park_counter );
827    whileBody->push_back(
828        new ExprStmt( loc,
829            new UntypedExpr ( loc,
830                new NameExpr( loc, "__CFA_maybe_park" ),
831                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
832            )
833        )
834    );
835
836    whileBody->push_back( genStatusCheckFor( stmt, clauseData, satName ) );
837
838    return new CompoundStmt( loc,
839        {
840            new WhileDoStmt( loc,
841                genNotSatExpr( stmt, satName, arrName ),
842                whileBody,  // body
843                {}          // no inits
844            ),
845            genStatusCheckFor( stmt, clauseData, runName )
846        }
847    );
848}
849
850// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
851// typeof(target) & __clause_target_0 = target;
852// bool __when_cond_0 = when_cond; // only generated if when_cond defined
853// select_node clause1;
854void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
855    ClauseData * currClause;
856    for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
857        currClause = new ClauseData( i, statusName );
858        currClause->nodeName = namer_node.newName();
859        currClause->targetName = namer_target.newName();
860        currClause->whenName = namer_when.newName();
861        clauseData.push_back(currClause);
862        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
863
864        // typeof(target) & __clause_target_0 = target;
865        body->push_back(
866            new DeclStmt( cLoc,
867                new ObjectDecl( cLoc,
868                    currClause->targetName,
869                    new ReferenceType( new TypeofType( ast::deepCopy( stmt->clauses.at(i)->target ) ) ),
870                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
871                )
872            )
873        );
874
875        // bool __when_cond_0 = when_cond; // only generated if when_cond defined
876        if ( stmt->clauses.at(i)->when_cond )
877            body->push_back(
878                new DeclStmt( cLoc,
879                    new ObjectDecl( cLoc,
880                        currClause->whenName,
881                        new BasicType( BasicType::Kind::Bool ),
882                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
883                    )
884                )
885            );
886       
887        // select_node clause1;
888        body->push_back(
889            new DeclStmt( cLoc,
890                new ObjectDecl( cLoc,
891                    currClause->nodeName,
892                    new StructInstType( selectNodeDecl )
893                )
894            )
895        );
896    }
897
898    if ( stmt->else_stmt && stmt->else_cond ) {
899        body->push_back(
900            new DeclStmt( stmt->else_cond->location,
901                new ObjectDecl( stmt->else_cond->location,
902                    elseWhenName,
903                    new BasicType( BasicType::Kind::Bool ),
904                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
905                )
906            )
907        );
908    }
909}
910
911/*
912if ( clause_status == &clause1 ) ... clause 1 body ...
913...
914elif ( clause_status == &clausen ) ... clause n body ...
915*/
916Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
917    const CodeLocation & loc = stmt->location;
918
919    IfStmt * outerIf = nullptr;
920        IfStmt * lastIf = nullptr;
921
922        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
923        for ( long unsigned int i = 0; i < data.size(); i++ ) {
924        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
925
926                IfStmt * currIf = new IfStmt( cLoc,
927                        new UntypedExpr( cLoc, 
928                new NameExpr( cLoc, "?==?" ), 
929                {
930                    new NameExpr( cLoc, statusName ),
931                    new CastExpr( cLoc, 
932                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
933                        new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
934                    )
935                }
936            ),
937            genStmtBlock( stmt->clauses.at(i), data.at(i) )
938                );
939               
940                if ( i == 0 ) {
941                        outerIf = currIf;
942                } else {
943                        // add ifstmt to else of previous stmt
944                        lastIf->else_ = currIf;
945                }
946
947                lastIf = currIf;
948        }
949
950    return new CompoundStmt( loc,
951        {
952            new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
953            outerIf
954        }
955    );
956}
957
958Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
959    if ( idx == data.size() ) {   // base case, gen last else
960        const CodeLocation & cLoc = stmt->else_stmt->location;
961        if ( !stmt->else_stmt ) // normal non-else gen
962            return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
963
964        Expr * raceFnCall = new UntypedExpr( stmt->location,
965            new NameExpr( stmt->location, "__select_node_else_race" ),
966            { new NameExpr( stmt->location, data.at(0)->nodeName ) }
967        );
968
969        if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
970            return new IfStmt( cLoc,
971                new LogicalExpr( cLoc,
972                    new CastExpr( cLoc,
973                        new NameExpr( cLoc, elseWhenName ),
974                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
975                    ),
976                    new CastExpr( cLoc,
977                        raceFnCall,
978                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
979                    ),
980                    LogicalFlag::AndExpr
981                ),
982                ast::deepCopy( stmt->else_stmt ),
983                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
984            );
985        }
986
987        // return else conditional on race
988        return new IfStmt( stmt->else_stmt->location,
989            raceFnCall,
990            ast::deepCopy( stmt->else_stmt ),
991            buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
992        );
993    }
994    const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
995
996    Expr * baseCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
997    Expr * ifCond;
998
999    // If we have a when_cond make the register call conditional on it
1000    if ( stmt->clauses.at(idx)->when_cond ) {
1001        ifCond = new LogicalExpr( cLoc,
1002            new CastExpr( cLoc,
1003                new NameExpr( cLoc, data.at(idx)->whenName ), 
1004                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1005            ),
1006            new CastExpr( cLoc,
1007                baseCond,
1008                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1009            ),
1010            LogicalFlag::AndExpr
1011        );
1012    } else ifCond = baseCond;
1013
1014    return new CompoundStmt( cLoc,
1015        {   // gens: setup_clause( clause1, &status, 0p );
1016            new ExprStmt( cLoc,
1017                new UntypedExpr ( cLoc,
1018                    new NameExpr( cLoc, "setup_clause" ),
1019                    {
1020                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1021                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1022                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::SignedInt ) ) )
1023                    }
1024                )
1025            ),
1026            // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1027            new IfStmt( cLoc,
1028                ifCond,
1029                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1030                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1031            )
1032        }
1033    );
1034}
1035
1036// This gens the special case of an all OR waituntil:
1037/*
1038int status = 0;
1039
1040typeof(target) & __clause_target_0 = target;
1041bool __when_cond_0 = when_cond; // only generated if when_cond defined
1042select_node clause1;
1043... generate above for rest of clauses ...
1044
1045try {
1046    setup_clause( clause1, &status, 0p );
1047    if ( __when_cond_0 && register_select( 1 ) ) {
1048        ... clause 1 body ...
1049    } else {
1050        ... recursively gen for each of n clauses ...
1051        setup_clause( clausen, &status, 0p );
1052        if ( __when_cond_n-1 && register_select( n ) ) {
1053            ... clause n body ...
1054        } else {
1055            if ( else_when ) ... else clause body ...
1056            else {
1057                park();
1058
1059                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1060                if ( clause_status == &clause1) ... clause 1 body ...
1061                ...
1062                elif ( clause_status == &clausen ) ... clause n body ...
1063            }
1064        }
1065    }
1066}
1067finally {
1068    if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1069    ...
1070    if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1071}
1072*/
1073Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1074    const CodeLocation & loc = stmt->location;
1075    string statusName = namer_status.newName();
1076    string elseWhenName = namer_when.newName();
1077    int numClauses = stmt->clauses.size();
1078    CompoundStmt * body = new CompoundStmt( stmt->location );
1079
1080    // Generates: unsigned long int status = 0;
1081    body->push_back( new DeclStmt( loc,
1082        new ObjectDecl( loc,
1083            statusName,
1084            new BasicType( BasicType::Kind::LongUnsignedInt ),
1085            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1086        )
1087    ));
1088
1089    vector<ClauseData *> clauseData;
1090    genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1091
1092    vector<int> whenIndices; // track which clauses have whens
1093
1094    CompoundStmt * unregisters = new CompoundStmt( loc );
1095    Expr * ifCond;
1096    for ( int i = 0; i < numClauses; i++ ) {
1097        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1098        // Gens: node.status != 0p
1099        UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc, 
1100            new NameExpr( cLoc, "?!=?" ), 
1101            {
1102                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) ) ),
1103                new UntypedExpr( cLoc, 
1104                    new NameExpr( cLoc, "__get_clause_status" ), 
1105                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) } 
1106                ) 
1107            }
1108        );
1109
1110        // If we have a when_cond make the unregister call conditional on it
1111        if ( stmt->clauses.at(i)->when_cond ) {
1112            whenIndices.push_back(i);
1113            ifCond = new LogicalExpr( cLoc,
1114                new CastExpr( cLoc,
1115                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1116                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1117                ),
1118                new CastExpr( cLoc,
1119                    statusPtrCheck,
1120                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1121                ),
1122                LogicalFlag::AndExpr
1123            );
1124        } else ifCond = statusPtrCheck;
1125       
1126        unregisters->push_back(
1127            new IfStmt( cLoc,
1128                ifCond,
1129                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) ) 
1130            )
1131        );
1132    }
1133
1134    if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1135        body->push_back(
1136                new ast::TryStmt( loc,
1137                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1138                {},
1139                new ast::FinallyClause( loc, unregisters )
1140            )
1141        );
1142    } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1143        Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1144        Expr * lastExpr = outerIfCond;
1145
1146        for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1147            outerIfCond = new LogicalExpr( loc,
1148                new CastExpr( loc,
1149                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ), 
1150                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1151                ),
1152                new CastExpr( loc,
1153                    lastExpr,
1154                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1155                ),
1156                LogicalFlag::OrExpr
1157            );
1158            lastExpr = outerIfCond;
1159        }
1160
1161        body->push_back(
1162                new ast::TryStmt( loc,
1163                new CompoundStmt( loc, 
1164                    {
1165                        new IfStmt( loc,
1166                            outerIfCond,
1167                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1168                        )
1169                    }
1170                ),
1171                {},
1172                new ast::FinallyClause( loc, unregisters )
1173            )
1174        );
1175    }
1176
1177    for ( ClauseData * datum : clauseData )
1178        delete datum;
1179
1180    return body;
1181}
1182
1183Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1184    if ( !selectNodeDecl )
1185        SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1186
1187    // Prep clause tree to figure out how to set initial statuses
1188    // setTreeSizes( stmt->predicateTree );
1189    if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1190        return genAllOr( stmt );
1191
1192    CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1193    CompoundStmt * body = new CompoundStmt( stmt->location );
1194    string statusArrName = namer_status.newName();
1195    string pCountName = namer_park.newName();
1196    string satName = namer_sat.newName();
1197    string runName = namer_run.newName();
1198    string elseWhenName = namer_when.newName();
1199    int numClauses = stmt->clauses.size();
1200    addPredicates( stmt, satName, runName );
1201
1202    const CodeLocation & loc = stmt->location;
1203
1204    // Generates: int park_counter = 0;
1205    body->push_back( new DeclStmt( loc,
1206        new ObjectDecl( loc,
1207            pCountName,
1208            new BasicType( BasicType::Kind::SignedInt ),
1209            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1210        )
1211    ));
1212
1213    // Generates: int clause_statuses[3] = { 0 };
1214    body->push_back( new DeclStmt( loc,
1215        new ObjectDecl( loc,
1216            statusArrName,
1217            new ArrayType( new BasicType( BasicType::Kind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1218            new ListInit( loc,
1219                {
1220                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1221                }
1222            )
1223        )
1224    ));
1225
1226    vector<ClauseData *> clauseData;
1227    genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1228
1229    vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1230    vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1231
1232    collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1233
1234    // This is only needed for clauses that have AND as a parent and a when_cond defined
1235    // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1236    for ( int idx : andWhenClauses ) {
1237        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1238        body->push_back( 
1239            new IfStmt( cLoc,
1240                new UntypedExpr ( cLoc,
1241                    new NameExpr( cLoc, "!?" ),
1242                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1243                ),  // IfStmt cond
1244                new ExprStmt( cLoc,
1245                    new UntypedExpr ( cLoc,
1246                        new NameExpr( cLoc, "?=?" ),
1247                        {
1248                            new UntypedExpr ( cLoc, 
1249                                new NameExpr( cLoc, "?[?]" ),
1250                                {
1251                                    new NameExpr( cLoc, statusArrName ),
1252                                    ConstantExpr::from_int( cLoc, idx )
1253                                }
1254                            ),
1255                            new NameExpr( cLoc, "__SELECT_RUN" )
1256                        }
1257                    )
1258                )  // IfStmt then
1259            )
1260        );
1261    }
1262
1263    // Only need to generate conditional initial state setting for ambiguous when clauses
1264    if ( !ambiguousClauses.empty() ) {
1265        body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1266    }
1267
1268    // generates the following for each clause:
1269    // setup_clause( clause1, &clause_statuses[0], &park_counter );
1270    // register_select(A, clause1);
1271    for ( int i = 0; i < numClauses; i++ ) {
1272        setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1273    }
1274
1275    // generate satisfy logic based on if there is an else clause and if it is conditional
1276    if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1277        tryBody->push_back(
1278            new IfStmt( stmt->else_cond->location,
1279                new NameExpr( stmt->else_cond->location, elseWhenName ),
1280                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1281                genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData )
1282            )
1283        );
1284    } else if ( !stmt->else_stmt ) { // normal gen
1285        tryBody->push_back( genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData ) );
1286    } else { // generate just else
1287        tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1288    }
1289
1290    // Collection of unregister calls on resources to be put in finally clause
1291    // for each clause:
1292    // if ( !__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1293    // OR if when( ... ) defined on resource
1294    // if ( when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1295    CompoundStmt * unregisters = new CompoundStmt( loc );
1296
1297    Expr * statusExpr; // !__CFA_has_clause_run( clause_statuses[i] )
1298    for ( int i = 0; i < numClauses; i++ ) {
1299        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1300
1301        // Generates: !__CFA_has_clause_run( clause_statuses[i] )
1302        statusExpr = new UntypedExpr ( cLoc,
1303            new NameExpr( cLoc, "!?" ),
1304            {
1305                new UntypedExpr ( cLoc, 
1306                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1307                    {
1308                        genArrAccessExpr( cLoc, i, statusArrName )
1309                    }
1310                )
1311            }
1312        );
1313       
1314        // Generates:
1315        // (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1316        statusExpr = new LogicalExpr( cLoc,
1317            new CastExpr( cLoc,
1318                statusExpr, 
1319                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1320            ),
1321            new CastExpr( cLoc,
1322                genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1323                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1324            ),
1325            LogicalFlag::AndExpr
1326        );
1327       
1328        // if when cond defined generates:
1329        // when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1330        if ( stmt->clauses.at(i)->when_cond )
1331            statusExpr = new LogicalExpr( cLoc,
1332                new CastExpr( cLoc,
1333                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1334                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1335                ),
1336                new CastExpr( cLoc,
1337                    statusExpr,
1338                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1339                ),
1340                LogicalFlag::AndExpr
1341            );
1342
1343        // generates:
1344        // if ( statusExpr ) { ... clausei stmt ... }
1345        unregisters->push_back( 
1346            new IfStmt( cLoc,
1347                statusExpr,
1348                new CompoundStmt( cLoc,
1349                    {
1350                        new IfStmt( cLoc,
1351                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "on_selected" ),
1352                            ast::deepCopy( stmt->clauses.at(i)->stmt )
1353                        )
1354                    }
1355                )
1356            )
1357        );
1358
1359        // // generates:
1360        // // if ( statusExpr ) { ... clausei stmt ... }
1361        // unregisters->push_back(
1362        //     new IfStmt( cLoc,
1363        //         statusExpr,
1364        //         genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1365        //     )
1366        // );
1367    }
1368
1369    body->push_back( 
1370        new ast::TryStmt(
1371            loc,
1372            tryBody,
1373            {},
1374            new ast::FinallyClause( loc, unregisters )
1375        )
1376    );
1377
1378    for ( ClauseData * datum : clauseData )
1379        delete datum;
1380
1381    return body;
1382}
1383
1384// To add the predicates at global scope we need to do it in a second pass
1385// Predicates are added after "struct select_node { ... };"
1386class AddPredicateDecls final : public WithDeclsToAdd<> {
1387    vector<FunctionDecl *> & satFns;
1388    const StructDecl * selectNodeDecl = nullptr;
1389
1390  public:
1391    void previsit( const StructDecl * decl ) {
1392        if ( !decl->body ) {
1393            return;
1394        } else if ( "select_node" == decl->name ) {
1395            assert( !selectNodeDecl );
1396            selectNodeDecl = decl;
1397            for ( FunctionDecl * fn : satFns )
1398                declsToAddAfter.push_back(fn);           
1399        }
1400    }
1401    AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1402};
1403
1404void generateWaitUntil( TranslationUnit & translationUnit ) {
1405    vector<FunctionDecl *> satFns;
1406        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1407    Pass<AddPredicateDecls>::run( translationUnit, satFns );
1408}
1409
1410} // namespace Concurrency
1411
1412// Local Variables: //
1413// tab-width: 4 //
1414// mode: c++ //
1415// compile-command: "make install" //
1416// End: //
Note: See TracBrowser for help on using the repository browser.