source: src/Concurrency/Waituntil.cpp @ 7a780ad

Last change on this file since 7a780ad was 7a780ad, checked in by Andrew Beach <ajbeach@…>, 6 weeks ago

Moved ast::BasicType::Kind to ast::BasicKind? in its own hearder. This is more consistent with other utility enums (although we still use this as a enum class) and reduces what some files need to include. Also did a upgrade in a comment with MAX_INTEGER_TYPE, it is now part of the enum.

  • Property mode set to 100644
File size: 58.0 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waituntil.hpp"
17
18#include <string>
19
20#include "AST/Copy.hpp"
21#include "AST/Expr.hpp"
22#include "AST/Pass.hpp"
23#include "AST/Print.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/Type.hpp"
26#include "Common/UniqueName.h"
27
28using namespace ast;
29using namespace std;
30
31/* So this is what this pass dones:
32{
33    when ( condA ) waituntil( A ){ doA(); }
34    or when ( condB ) waituntil( B ){ doB(); }
35    and when ( condC ) waituntil( C ) { doC(); }
36}
37                 ||
38                 ||
39                \||/
40                 \/
41
42Generates these two routines:
43static inline bool is_full_sat_1( int * clause_statuses ) {
44    return clause_statuses[0]
45        || clause_statuses[1]
46        && clause_statuses[2];
47}
48
49static inline bool is_done_sat_1( int * clause_statuses ) {
50    return has_run(clause_statuses[0])
51        || has_run(clause_statuses[1])
52        && has_run(clause_statuses[2]);
53}
54
55Replaces the waituntil statement above with the following code:
56{
57    // used with atomic_dec/inc to get binary semaphore behaviour
58    int park_counter = 0;
59
60    // status (one for each clause)
61    int clause_statuses[3] = { 0 };
62
63    bool whenA = condA;
64    bool whenB = condB;
65    bool whenC = condC;
66
67    if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
68    if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
69
70    // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
71
72    // three blocks
73    // for each block, create, setup, then register select_node
74    select_node clause1;
75    select_node clause2;
76    select_node clause3;
77
78    try {
79        if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
80        ... repeat ^ for B and C ...
81
82        // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
83
84        // loop & park until done
85        while( !is_full_sat_1( clause_statuses ) ) {
86           
87            // binary sem P();
88            if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
89                park();
90           
91            // execute any blocks available with status set to 0
92            for ( int i = 0; i < 3; i++ ) {
93                if (clause_statuses[i] == __SELECT_SAT) {
94                    switch (i) {
95                        case 0:
96                            try {
97                                    on_selected( A, clause1 );
98                                    doA();
99                            }
100                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
101                            break;
102                        case 1:
103                            ... same gen as A but for B and clause2 ...
104                            break;
105                        case 2:
106                            ... same gen as A but for C and clause3 ...
107                            break;
108                    }
109                }
110            }
111        }
112
113        // ensure that the blocks that triggered is_full_sat_1 are run
114        // by running every un-run block that is SAT from the start until
115        // the predicate is SAT when considering RUN status = true
116        for ( int i = 0; i < 3; i++ ) {
117            if (is_done_sat_1( clause_statuses )) break;
118            if (clause_statuses[i] == __SELECT_SAT)
119                ... Same if body here as in loop above ...
120        }
121    } finally {
122        // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
123        // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
124        if ( !has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) )
125            on_selected( A, clause1 )
126            doA();
127        ... repeat if above for B and C ...
128    }
129}
130
131*/
132
133namespace Concurrency {
134
135class GenerateWaitUntilCore final {
136    vector<FunctionDecl *> & satFns;
137        UniqueName namer_sat = "__is_full_sat_"s;
138    UniqueName namer_run = "__is_run_sat_"s;
139        UniqueName namer_park = "__park_counter_"s;
140        UniqueName namer_status = "__clause_statuses_"s;
141        UniqueName namer_node = "__clause_"s;
142    UniqueName namer_target = "__clause_target_"s;
143    UniqueName namer_when = "__when_cond_"s;
144    UniqueName namer_label = "__waituntil_label_"s;
145
146    string idxName = "__CFA_clause_idx_";
147
148    struct ClauseData {
149        string nodeName;
150        string targetName;
151        string whenName;
152        int index;
153        string & statusName;
154        ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
155    };
156
157    const StructDecl * selectNodeDecl = nullptr;
158
159    // This first set of routines are all used to do the complicated job of
160    //    dealing with how to set predicate statuses with certain when_conds T/F
161    //    so that the when_cond == F effectively makes that clause "disappear"
162    void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
163    void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
164    bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
165    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
166    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
167    void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
168    void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
169    void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
170    CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
171    Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
172
173    // These routines are just code-gen helpers
174    void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
175    void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
176    CompoundStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
177    Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
178    CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
179    Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
180    Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
181    void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
182    Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
183    Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
184    Stmt * genAllOr( const WaitUntilStmt * stmt );
185
186  public:
187    void previsit( const StructDecl * decl );
188        Stmt * postvisit( const WaitUntilStmt * stmt );
189    GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
190};
191
192// Finds select_node decl
193void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
194    if ( !decl->body ) {
195                return;
196        } else if ( "select_node" == decl->name ) {
197                assert( !selectNodeDecl );
198                selectNodeDecl = decl;
199        }
200}
201
202void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
203    // all children when-ambiguous
204    if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
205        // true iff an ancestor/descendant has a different operation
206        currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
207    // ambiguousWhen is initially false so theres no need to set it here
208}
209
210// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
211// This tree painting is needed to generate the if statements that set the initial state
212//    of the clause statuses when some clauses are turned off via when_cond
213// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
214// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
215// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
216void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
217    bool aBelow = false; // updated by child nodes
218    bool oBelow = false; // updated by child nodes
219    switch (currNode->op) {
220        case WaitUntilStmt::ClauseNode::AND:
221            paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
222            paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
223
224            // update currNode's when flag based on conditions listed in fn signature comment above
225            updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
226
227            // set return flags to tell parents which decendant ops have been seen
228            andBelow = true;
229            orBelow = oBelow;
230            return;
231        case WaitUntilStmt::ClauseNode::OR:
232            paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
233            paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
234
235            // update currNode's when flag based on conditions listed in fn signature comment above
236            updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
237
238            // set return flags to tell parents which decendant ops have been seen
239            andBelow = aBelow;
240            orBelow = true;
241            return;
242        case WaitUntilStmt::ClauseNode::LEAF:
243            if ( currNode->leaf->when_cond )
244                currNode->ambiguousWhen = true;
245            return;
246        default:
247            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
248    }
249}
250
251// overloaded wrapper for paintWhenTree that sets initial values
252// returns true if entire tree is OR's (special case)
253bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
254    bool aBelow = false, oBelow = false; // unused by initial call
255    paintWhenTree( currNode, false, false, aBelow, oBelow );
256    return !aBelow;
257}
258
259// Helper: returns Expr that represents arrName[index]
260Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
261    return new UntypedExpr ( loc, 
262        new NameExpr( loc, "?[?]" ),
263        {
264            new NameExpr( loc, arrName ),
265            ConstantExpr::from_int( loc, index )
266        }
267    );
268}
269
270// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
271// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
272// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
273// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
274void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
275    switch (currNode->op) {
276        case WaitUntilStmt::ClauseNode::AND:
277            collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
278            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
279            return;
280        case WaitUntilStmt::ClauseNode::OR:
281            collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
282            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
283            return;
284        case WaitUntilStmt::ClauseNode::LEAF:
285            if ( parentAmbig ) {
286                ambigIdxs.push_back(make_pair(index, currNode));
287            }
288            if ( parentAnd && currNode->leaf->when_cond ) {
289                currNode->childOfAnd = true;
290                andIdxs.push_back(index);
291            }
292            index++;
293            return;
294        default:
295            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
296    }
297}
298
299// overloaded wrapper for collectWhens that sets initial values
300void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
301    int idx = 0;
302    collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
303}
304
305// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
306//    subtrees are "turned off"
307// sets whenState = false iff both children have whenState == false.
308// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
309// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
310void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
311    if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
312    updateWhenState( currNode->left );
313    updateWhenState( currNode->right );
314    if ( !currNode->left->whenState && !currNode->right->whenState )
315        currNode->whenState = false;
316    else 
317        currNode->whenState = true;
318}
319
320// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
321// assumes that this will only be called on subtrees that are entirely whenState == false
322void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
323    if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
324        || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
325        // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
326        genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
327        genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
328    } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
329        || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
330        // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
331        CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
332        CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
333
334        // only one side needs to evaluate to status so we recurse on both subtrees
335        //    but only keep the statements from the subtree with minimal statements
336        genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
337        genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
338       
339        // append minimal statements to retStmt
340        if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
341            retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
342        } else {
343            retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
344        }
345       
346        delete leftStmt;
347        delete rightStmt;
348    } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
349        const CodeLocation & loc = stmt->location;
350        if ( status && !currNode->childOfAnd ) {
351            retStmt->push_back(
352                new ExprStmt( loc, 
353                    UntypedExpr::createAssign( loc,
354                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
355                        new NameExpr( loc, "__SELECT_RUN" )
356                    )
357                )
358            );
359        } else if ( !status && currNode->childOfAnd ) {
360            retStmt->push_back(
361                new ExprStmt( loc, 
362                    UntypedExpr::createAssign( loc,
363                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
364                        new NameExpr( loc, "__SELECT_UNSAT" )
365                    )
366                )
367            );
368        }
369
370        // No need to generate statements for the following cases since childOfAnd are always set to true
371        //    and !childOfAnd are always false
372        // - status && currNode->childOfAnd
373        // - !status && !currNode->childOfAnd
374        idx++;
375    }
376}
377
378void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
379    switch (currNode->op) {
380        case WaitUntilStmt::ClauseNode::AND:
381            // check which subtrees have all whenState == false (disabled)
382            if (!currNode->left->whenState && !currNode->right->whenState) {
383                // this case can only occur when whole tree is disabled since otherwise
384                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
385                assert( !currNode->whenState ); // paranoidWWW
386                // whole tree disabled so pass true so that select is SAT vacuously
387                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
388            } else if ( !currNode->left->whenState ) {
389                // pass true since x && true === x
390                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
391                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
392            } else if ( !currNode->right->whenState ) {
393                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
394                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
395            } else { 
396                // if no children with whenState == false recurse normally via break
397                break;
398            }
399            return;
400        case WaitUntilStmt::ClauseNode::OR:
401            if (!currNode->left->whenState && !currNode->right->whenState) {
402                assert( !currNode->whenState ); // paranoid
403                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
404            } else if ( !currNode->left->whenState ) {
405                // pass false since x || false === x
406                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
407                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
408            } else if ( !currNode->right->whenState ) {
409                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
410                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
411            } else { 
412                break;
413            }
414            return;
415        case WaitUntilStmt::ClauseNode::LEAF:
416            idx++;
417            return;
418        default:
419            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
420    }
421    genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
422    genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
423}
424
425// generates a minimal set of assignments for status arr based on which whens are toggled on/off
426CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
427    updateWhenState( stmt->predicateTree );
428    CompoundStmt * retval = new CompoundStmt( stmt->location );
429    int idx = 0;
430    genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
431    return retval;
432}
433
434// generates nested if/elses for all possible assignments of ambiguous when_conds
435// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
436Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, 
437    vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
438    // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
439    //    Why is size_type parameterized on the type stored in the vector?????
440
441    const CodeLocation & loc = stmt->location;
442    int clauseIdx = ambigClauses.at(ambigIdx).first;
443    WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
444    Stmt * thenStmt;
445    Stmt * elseStmt;
446   
447    if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
448        currNode->whenState = true;
449        thenStmt = getStatusAssignment( stmt, clauseData );
450        currNode->whenState = false;
451        elseStmt = getStatusAssignment( stmt, clauseData );
452    } else {
453        // recurse both with when enabled and disabled to generate all possible cases
454        currNode->whenState = true;
455        thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
456        currNode->whenState = false;
457        elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
458    }
459
460    // insert first recursion result in if ( __when_cond_ ) { ... }
461    // insert second recursion result in else { ... }
462    return new CompoundStmt ( loc,
463        {
464            new IfStmt( loc,
465                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
466                thenStmt,
467                elseStmt
468            )
469        }
470    );
471}
472
473// typedef a fn ptr so that we can reuse genPredExpr
474// genLeafExpr is used to refer to one of the following two routines
475typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
476
477// return Expr that represents clause_statuses[index]
478// mutates index to be index + 1
479Expr * genSatExpr( const CodeLocation & loc, int & index ) {
480    return genArrAccessExpr( loc, index++, "clause_statuses" );
481}
482
483// return Expr that represents has_run(clause_statuses[index])
484Expr * genRunExpr( const CodeLocation & loc, int & index ) {
485    return new UntypedExpr ( loc, 
486        new NameExpr( loc, "__CFA_has_clause_run" ),
487        { genSatExpr( loc, index ) }
488    );
489}
490
491// Takes in the ClauseNode tree and recursively generates
492// the predicate expr used inside the predicate functions
493Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
494    Expr * leftExpr, * rightExpr;
495    switch (currNode->op) {
496        case WaitUntilStmt::ClauseNode::AND:
497            leftExpr = genPredExpr( loc, currNode->left, idx, genLeaf );
498            rightExpr = genPredExpr( loc, currNode->right, idx, genLeaf );
499            return new LogicalExpr( loc, 
500                new CastExpr( loc, leftExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
501                new CastExpr( loc, rightExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ), 
502                LogicalFlag::AndExpr
503            );
504            break;
505        case WaitUntilStmt::ClauseNode::OR:
506            leftExpr = genPredExpr( loc, currNode->left, idx, genLeaf );
507            rightExpr = genPredExpr( loc, currNode->right, idx, genLeaf );
508            return new LogicalExpr( loc,
509                new CastExpr( loc, leftExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
510                new CastExpr( loc, rightExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ), 
511                LogicalFlag::OrExpr );
512            break;
513        case WaitUntilStmt::ClauseNode::LEAF:
514            return genLeaf( loc, idx );
515            break;
516        default:
517            assertf(false, "Unreachable waituntil clause node type. How did you get here???");\
518            return nullptr;
519            break;
520    }
521    return nullptr;
522}
523
524
525// Builds the predicate functions used to check the status of the waituntil statement
526/* Ex:
527{
528    waituntil( A ){ doA(); }
529    or waituntil( B ){ doB(); }
530    and waituntil( C ) { doC(); }
531}
532generates =>
533static inline bool is_full_sat_1( int * clause_statuses ) {
534    return clause_statuses[0]
535        || clause_statuses[1]
536        && clause_statuses[2];
537}
538
539static inline bool is_done_sat_1( int * clause_statuses ) {
540    return has_run(clause_statuses[0])
541        || has_run(clause_statuses[1])
542        && has_run(clause_statuses[2]);
543}
544*/
545// Returns a predicate function decl
546// predName and genLeaf determine if this generates an is_done or an is_full predicate
547FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
548    int arrIdx = 0;
549    const CodeLocation & loc = stmt->location;
550    CompoundStmt * body = new CompoundStmt( loc );
551    body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
552
553    return new FunctionDecl( loc,
554        predName,
555        {
556            new ObjectDecl( loc,
557                "clause_statuses",
558                new PointerType( new BasicType( BasicKind::LongUnsignedInt ) )
559            )
560        },
561        {
562            new ObjectDecl( loc,
563                "sat_ret",
564                new BasicType( BasicKind::Bool )
565            )
566        },
567        body,               // body
568        { Storage::Static },    // storage
569        Linkage::Cforall,       // linkage
570        {},                     // attributes
571        { Function::Inline }
572    );
573}
574
575// Creates is_done and is_full predicates
576void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
577    if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
578        satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) ); 
579    satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
580}
581
582// Adds the following to body:
583// if ( when_cond ) { // this if is omitted if no when() condition
584//      setup_clause( clause1, &clause_statuses[0], &park_counter );
585//      register_select(A, clause1);
586// }
587void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {   
588    CompoundStmt * currBody = body;
589    const CodeLocation & loc = clause->location;
590
591    // If we have a when_cond make the initialization conditional
592    if ( clause->when_cond )
593        currBody = new CompoundStmt( loc );
594
595    // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
596    currBody->push_back( new ExprStmt( loc,
597        new UntypedExpr ( loc,
598            new NameExpr( loc, "setup_clause" ),
599            {
600                new NameExpr( loc, data->nodeName ),
601                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
602                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
603            }
604        )
605    ));
606
607    // Generates: register_select(A, clause1);
608    currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
609
610    // generates: if ( when_cond ) { ... currBody ... }
611    if ( clause->when_cond )
612        body->push_back( 
613            new IfStmt( loc,
614                new NameExpr( loc, data->whenName ),
615                currBody
616            )
617        );
618}
619
620// Used to generate a call to one of the select trait routines
621Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
622    const CodeLocation & loc = clause->location;
623    return new UntypedExpr ( loc,
624        new NameExpr( loc, fnName ),
625        {
626            new NameExpr( loc, data->targetName ),
627            new NameExpr( loc, data->nodeName )
628        }
629    );
630}
631
632// Generates:
633/* on_selected( target_1, node_1 ); ... corresponding body of target_1 ...
634*/
635CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
636    const CodeLocation & cLoc = clause->location;
637    return new CompoundStmt( cLoc,
638        {
639            new IfStmt( cLoc,
640                genSelectTraitCall( clause, data, "on_selected" ),
641                ast::deepCopy( clause->stmt )
642            )
643        }
644    );
645}
646
647// this routine generates and returns the following
648/*for ( int i = 0; i < numClauses; i++ ) {
649    if ( predName(clause_statuses) ) break;
650    if (clause_statuses[i] == __SELECT_SAT) {
651        switch (i) {
652            case 0:
653                try {
654                    on_selected( target1, clause1 );
655                    dotarget1stmt();
656                }
657                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
658                break;
659            ...
660            case N:
661                ...
662                break;
663        }
664    }
665}*/
666CompoundStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
667    CompoundStmt * ifBody = new CompoundStmt( stmt->location );
668    const CodeLocation & loc = stmt->location;
669
670    string switchLabel = namer_label.newName();
671
672    /* generates:
673    switch (i) {
674        case 0:
675            try {
676                on_selected( target1, clause1 );
677                dotarget1stmt();
678            }
679            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
680            break;
681            ...
682        case N:
683            ...
684            break;
685    }*/
686    std::vector<ptr<CaseClause>> switchCases;
687    int idx = 0;
688    for ( const auto & clause: stmt->clauses ) {
689        const CodeLocation & cLoc = clause->location;
690        switchCases.push_back(
691            new CaseClause( cLoc,
692                ConstantExpr::from_int( cLoc, idx ),
693                {
694                    new CompoundStmt( cLoc,
695                        {
696                            new ast::TryStmt( cLoc,
697                                genStmtBlock( clause, clauseData.at(idx) ),
698                                {},
699                                new ast::FinallyClause( cLoc, 
700                                    new CompoundStmt( cLoc,
701                                        {
702                                            new ExprStmt( loc,
703                                                new UntypedExpr ( loc,
704                                                    new NameExpr( loc, "?=?" ),
705                                                    {
706                                                        new UntypedExpr ( loc, 
707                                                            new NameExpr( loc, "?[?]" ),
708                                                            {
709                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
710                                                                new NameExpr( loc, idxName )
711                                                            }
712                                                        ),
713                                                        new NameExpr( loc, "__SELECT_RUN" )
714                                                    }
715                                                )
716                                            ),
717                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
718                                        }
719                                    )
720                                )
721                            ),
722                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc, switchLabel ) )
723                        }
724                    )
725                }
726            )
727        );
728        idx++;
729    }
730
731    ifBody->push_back(
732        new SwitchStmt( loc,
733            new NameExpr( loc, idxName ),
734            std::move( switchCases ),
735            { Label( loc, switchLabel ) }
736        )
737    );
738
739    // gens:
740    // if (clause_statuses[i] == __SELECT_SAT) {
741    //      ... ifBody  ...
742    // }
743    IfStmt * ifSwitch = new IfStmt( loc,
744        new UntypedExpr ( loc,
745            new NameExpr( loc, "?==?" ),
746            {
747                new UntypedExpr ( loc, 
748                    new NameExpr( loc, "?[?]" ),
749                    {
750                        new NameExpr( loc, clauseData.at(0)->statusName ),
751                        new NameExpr( loc, idxName )
752                    }
753                ),
754                new NameExpr( loc, "__SELECT_SAT" )
755            }
756        ),      // condition
757        ifBody  // body
758    );
759
760    string forLabel = namer_label.newName();
761
762    // we hoist init here so that this pass can happen after hoistdecls pass
763    return new CompoundStmt( loc,
764        {
765            new DeclStmt( loc,
766                new ObjectDecl( loc,
767                    idxName,
768                    new BasicType( BasicKind::SignedInt ),
769                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
770                )
771            ),
772            new ForStmt( loc,
773                {},  // inits
774                new UntypedExpr ( loc,
775                    new NameExpr( loc, "?<?" ),
776                    {
777                        new NameExpr( loc, idxName ),
778                        ConstantExpr::from_int( loc, stmt->clauses.size() )
779                    }
780                ),  // cond
781                new UntypedExpr ( loc,
782                    new NameExpr( loc, "?++" ),
783                    { new NameExpr( loc, idxName ) }
784                ),  // inc
785                new CompoundStmt( loc,
786                    {
787                        new IfStmt( loc,
788                            new UntypedExpr ( loc,
789                                new NameExpr( loc, predName ),
790                                { new NameExpr( loc, clauseData.at(0)->statusName ) }
791                            ),
792                            new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc, forLabel ) )
793                        ),
794                        ifSwitch
795                    }
796                ),   // body
797                { Label( loc, forLabel ) }
798            )
799        }
800    );
801}
802
803// Generates: !is_full_sat_n() / !is_run_sat_n()
804Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
805    const CodeLocation & loc = stmt->location;
806    return new UntypedExpr ( loc,
807        new NameExpr( loc, "!?" ),
808        {
809            new UntypedExpr ( loc,
810                new NameExpr( loc, satName ),
811                { new NameExpr( loc, arrName ) }
812            )
813        }
814    );
815}
816
817// Generates the code needed for waituntils with an else ( ... )
818// Checks clauses once after registering for completion and runs them if completes
819// If not enough have run to satisfy predicate after one pass then the else is run
820Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
821    return new CompoundStmt( stmt->else_stmt->location,
822        {
823            genStatusCheckFor( stmt, clauseData, runName ),
824            new IfStmt( stmt->else_stmt->location,
825                genNotSatExpr( stmt, runName, arrName ),
826                ast::deepCopy( stmt->else_stmt )
827            )
828        }
829    );
830}
831
832Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
833    CompoundStmt * whileBody = new CompoundStmt( stmt->location );
834    const CodeLocation & loc = stmt->location;
835
836    // generates: __CFA_maybe_park( &park_counter );
837    whileBody->push_back(
838        new ExprStmt( loc,
839            new UntypedExpr ( loc,
840                new NameExpr( loc, "__CFA_maybe_park" ),
841                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
842            )
843        )
844    );
845
846    whileBody->push_back( genStatusCheckFor( stmt, clauseData, runName ) );
847
848    return new CompoundStmt( loc,
849        {
850            new WhileDoStmt( loc,
851                genNotSatExpr( stmt, runName, arrName ),
852                whileBody,  // body
853                {}          // no inits
854            )
855        }
856    );
857}
858
859// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
860// typeof(target) & __clause_target_0 = target;
861// bool __when_cond_0 = when_cond; // only generated if when_cond defined
862// select_node clause1;
863void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
864    ClauseData * currClause;
865    for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
866        currClause = new ClauseData( i, statusName );
867        currClause->nodeName = namer_node.newName();
868        currClause->targetName = namer_target.newName();
869        currClause->whenName = namer_when.newName();
870        clauseData.push_back(currClause);
871        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
872
873        // typeof(target) & __clause_target_0 = target;
874        body->push_back(
875            new DeclStmt( cLoc,
876                new ObjectDecl( cLoc,
877                    currClause->targetName,
878                    new ReferenceType( 
879                        new TypeofType( new UntypedExpr( cLoc,
880                            new NameExpr( cLoc, "__CFA_select_get_type" ),
881                            { ast::deepCopy( stmt->clauses.at(i)->target ) }
882                        ))
883                    ),
884                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
885                )
886            )
887        );
888
889        // bool __when_cond_0 = when_cond; // only generated if when_cond defined
890        if ( stmt->clauses.at(i)->when_cond )
891            body->push_back(
892                new DeclStmt( cLoc,
893                    new ObjectDecl( cLoc,
894                        currClause->whenName,
895                        new BasicType( BasicKind::Bool ),
896                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
897                    )
898                )
899            );
900       
901        // select_node clause1;
902        body->push_back(
903            new DeclStmt( cLoc,
904                new ObjectDecl( cLoc,
905                    currClause->nodeName,
906                    new StructInstType( selectNodeDecl )
907                )
908            )
909        );
910    }
911
912    if ( stmt->else_stmt && stmt->else_cond ) {
913        body->push_back(
914            new DeclStmt( stmt->else_cond->location,
915                new ObjectDecl( stmt->else_cond->location,
916                    elseWhenName,
917                    new BasicType( BasicKind::Bool ),
918                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
919                )
920            )
921        );
922    }
923}
924
925/*
926if ( clause_status == &clause1 ) ... clause 1 body ...
927...
928elif ( clause_status == &clausen ) ... clause n body ...
929*/
930Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
931    const CodeLocation & loc = stmt->location;
932
933    IfStmt * outerIf = nullptr;
934        IfStmt * lastIf = nullptr;
935
936        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
937        for ( long unsigned int i = 0; i < data.size(); i++ ) {
938        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
939
940                IfStmt * currIf = new IfStmt( cLoc,
941                        new UntypedExpr( cLoc, 
942                new NameExpr( cLoc, "?==?" ), 
943                {
944                    new NameExpr( cLoc, statusName ),
945                    new CastExpr( cLoc, 
946                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
947                        new BasicType( BasicKind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
948                    )
949                }
950            ),
951            genStmtBlock( stmt->clauses.at(i), data.at(i) )
952                );
953               
954                if ( i == 0 ) {
955                        outerIf = currIf;
956                } else {
957                        // add ifstmt to else of previous stmt
958                        lastIf->else_ = currIf;
959                }
960
961                lastIf = currIf;
962        }
963
964    return new CompoundStmt( loc,
965        {
966            new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
967            outerIf
968        }
969    );
970}
971
972Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
973    if ( idx == data.size() ) {   // base case, gen last else
974        const CodeLocation & cLoc = stmt->else_stmt->location;
975        if ( !stmt->else_stmt ) // normal non-else gen
976            return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
977
978        Expr * raceFnCall = new UntypedExpr( stmt->location,
979            new NameExpr( stmt->location, "__select_node_else_race" ),
980            { new NameExpr( stmt->location, data.at(0)->nodeName ) }
981        );
982
983        if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
984            return new IfStmt( cLoc,
985                new LogicalExpr( cLoc,
986                    new CastExpr( cLoc,
987                        new NameExpr( cLoc, elseWhenName ),
988                        new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
989                    ),
990                    new CastExpr( cLoc,
991                        raceFnCall,
992                        new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
993                    ),
994                    LogicalFlag::AndExpr
995                ),
996                ast::deepCopy( stmt->else_stmt ),
997                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
998            );
999        }
1000
1001        // return else conditional on race
1002        return new IfStmt( stmt->else_stmt->location,
1003            raceFnCall,
1004            ast::deepCopy( stmt->else_stmt ),
1005            buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1006        );
1007    }
1008    const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1009
1010    Expr * baseCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
1011    Expr * ifCond;
1012
1013    // If we have a when_cond make the register call conditional on it
1014    if ( stmt->clauses.at(idx)->when_cond ) {
1015        ifCond = new LogicalExpr( cLoc,
1016            new CastExpr( cLoc,
1017                new NameExpr( cLoc, data.at(idx)->whenName ), 
1018                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1019            ),
1020            new CastExpr( cLoc,
1021                baseCond,
1022                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1023            ),
1024            LogicalFlag::AndExpr
1025        );
1026    } else ifCond = baseCond;
1027
1028    return new CompoundStmt( cLoc,
1029        {   // gens: setup_clause( clause1, &status, 0p );
1030            new ExprStmt( cLoc,
1031                new UntypedExpr ( cLoc,
1032                    new NameExpr( cLoc, "setup_clause" ),
1033                    {
1034                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1035                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1036                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicKind::SignedInt ) ) )
1037                    }
1038                )
1039            ),
1040            // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1041            new IfStmt( cLoc,
1042                ifCond,
1043                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1044                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1045            )
1046        }
1047    );
1048}
1049
1050// This gens the special case of an all OR waituntil:
1051/*
1052int status = 0;
1053
1054typeof(target) & __clause_target_0 = target;
1055bool __when_cond_0 = when_cond; // only generated if when_cond defined
1056select_node clause1;
1057... generate above for rest of clauses ...
1058
1059try {
1060    setup_clause( clause1, &status, 0p );
1061    if ( __when_cond_0 && register_select( 1 ) ) {
1062        ... clause 1 body ...
1063    } else {
1064        ... recursively gen for each of n clauses ...
1065        setup_clause( clausen, &status, 0p );
1066        if ( __when_cond_n-1 && register_select( n ) ) {
1067            ... clause n body ...
1068        } else {
1069            if ( else_when ) ... else clause body ...
1070            else {
1071                park();
1072
1073                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1074                if ( clause_status == &clause1) ... clause 1 body ...
1075                ...
1076                elif ( clause_status == &clausen ) ... clause n body ...
1077            }
1078        }
1079    }
1080}
1081finally {
1082    if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1083    ...
1084    if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1085}
1086*/
1087Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1088    const CodeLocation & loc = stmt->location;
1089    string statusName = namer_status.newName();
1090    string elseWhenName = namer_when.newName();
1091    int numClauses = stmt->clauses.size();
1092    CompoundStmt * body = new CompoundStmt( stmt->location );
1093
1094    // Generates: unsigned long int status = 0;
1095    body->push_back( new DeclStmt( loc,
1096        new ObjectDecl( loc,
1097            statusName,
1098            new BasicType( BasicKind::LongUnsignedInt ),
1099            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1100        )
1101    ));
1102
1103    vector<ClauseData *> clauseData;
1104    genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1105
1106    vector<int> whenIndices; // track which clauses have whens
1107
1108    CompoundStmt * unregisters = new CompoundStmt( loc );
1109    Expr * ifCond;
1110    for ( int i = 0; i < numClauses; i++ ) {
1111        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1112        // Gens: node.status != 0p
1113        UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc, 
1114            new NameExpr( cLoc, "?!=?" ), 
1115            {
1116                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicKind::LongUnsignedInt ) ) ),
1117                new UntypedExpr( cLoc, 
1118                    new NameExpr( cLoc, "__get_clause_status" ), 
1119                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) } 
1120                ) 
1121            }
1122        );
1123
1124        // If we have a when_cond make the unregister call conditional on it
1125        if ( stmt->clauses.at(i)->when_cond ) {
1126            whenIndices.push_back(i);
1127            ifCond = new LogicalExpr( cLoc,
1128                new CastExpr( cLoc,
1129                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1130                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1131                ),
1132                new CastExpr( cLoc,
1133                    statusPtrCheck,
1134                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1135                ),
1136                LogicalFlag::AndExpr
1137            );
1138        } else ifCond = statusPtrCheck;
1139       
1140        unregisters->push_back(
1141            new IfStmt( cLoc,
1142                ifCond,
1143                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) ) 
1144            )
1145        );
1146    }
1147
1148    if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1149        body->push_back(
1150                new ast::TryStmt( loc,
1151                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1152                {},
1153                new ast::FinallyClause( loc, unregisters )
1154            )
1155        );
1156    } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1157        Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1158        Expr * lastExpr = outerIfCond;
1159
1160        for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1161            outerIfCond = new LogicalExpr( loc,
1162                new CastExpr( loc,
1163                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ), 
1164                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1165                ),
1166                new CastExpr( loc,
1167                    lastExpr,
1168                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1169                ),
1170                LogicalFlag::OrExpr
1171            );
1172            lastExpr = outerIfCond;
1173        }
1174
1175        body->push_back(
1176                new ast::TryStmt( loc,
1177                new CompoundStmt( loc, 
1178                    {
1179                        new IfStmt( loc,
1180                            outerIfCond,
1181                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1182                        )
1183                    }
1184                ),
1185                {},
1186                new ast::FinallyClause( loc, unregisters )
1187            )
1188        );
1189    }
1190
1191    for ( ClauseData * datum : clauseData )
1192        delete datum;
1193
1194    return body;
1195}
1196
1197Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1198    if ( !selectNodeDecl )
1199        SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1200
1201    // Prep clause tree to figure out how to set initial statuses
1202    // setTreeSizes( stmt->predicateTree );
1203    if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1204        return genAllOr( stmt );
1205
1206    CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1207    CompoundStmt * body = new CompoundStmt( stmt->location );
1208    string statusArrName = namer_status.newName();
1209    string pCountName = namer_park.newName();
1210    string satName = namer_sat.newName();
1211    string runName = namer_run.newName();
1212    string elseWhenName = namer_when.newName();
1213    int numClauses = stmt->clauses.size();
1214    addPredicates( stmt, satName, runName );
1215
1216    const CodeLocation & loc = stmt->location;
1217
1218    // Generates: int park_counter = 0;
1219    body->push_back( new DeclStmt( loc,
1220        new ObjectDecl( loc,
1221            pCountName,
1222            new BasicType( BasicKind::SignedInt ),
1223            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1224        )
1225    ));
1226
1227    // Generates: int clause_statuses[3] = { 0 };
1228    body->push_back( new DeclStmt( loc,
1229        new ObjectDecl( loc,
1230            statusArrName,
1231            new ArrayType( new BasicType( BasicKind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1232            new ListInit( loc,
1233                {
1234                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1235                }
1236            )
1237        )
1238    ));
1239
1240    vector<ClauseData *> clauseData;
1241    genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1242
1243    vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1244    vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1245
1246    collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1247
1248    // This is only needed for clauses that have AND as a parent and a when_cond defined
1249    // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1250    for ( int idx : andWhenClauses ) {
1251        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1252        body->push_back( 
1253            new IfStmt( cLoc,
1254                new UntypedExpr ( cLoc,
1255                    new NameExpr( cLoc, "!?" ),
1256                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1257                ),  // IfStmt cond
1258                new ExprStmt( cLoc,
1259                    new UntypedExpr ( cLoc,
1260                        new NameExpr( cLoc, "?=?" ),
1261                        {
1262                            new UntypedExpr ( cLoc, 
1263                                new NameExpr( cLoc, "?[?]" ),
1264                                {
1265                                    new NameExpr( cLoc, statusArrName ),
1266                                    ConstantExpr::from_int( cLoc, idx )
1267                                }
1268                            ),
1269                            new NameExpr( cLoc, "__SELECT_RUN" )
1270                        }
1271                    )
1272                )  // IfStmt then
1273            )
1274        );
1275    }
1276
1277    // Only need to generate conditional initial state setting for ambiguous when clauses
1278    if ( !ambiguousClauses.empty() ) {
1279        body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1280    }
1281
1282    // generates the following for each clause:
1283    // setup_clause( clause1, &clause_statuses[0], &park_counter );
1284    // register_select(A, clause1);
1285    for ( int i = 0; i < numClauses; i++ ) {
1286        setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1287    }
1288
1289    // generate satisfy logic based on if there is an else clause and if it is conditional
1290    if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1291        tryBody->push_back(
1292            new IfStmt( stmt->else_cond->location,
1293                new NameExpr( stmt->else_cond->location, elseWhenName ),
1294                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1295                genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData )
1296            )
1297        );
1298    } else if ( !stmt->else_stmt ) { // normal gen
1299        tryBody->push_back( genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData ) );
1300    } else { // generate just else
1301        tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1302    }
1303
1304    // Collection of unregister calls on resources to be put in finally clause
1305    // for each clause:
1306    // if ( !__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1307    // OR if when( ... ) defined on resource
1308    // if ( when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1309    CompoundStmt * unregisters = new CompoundStmt( loc );
1310
1311    Expr * statusExpr; // !__CFA_has_clause_run( clause_statuses[i] )
1312    for ( int i = 0; i < numClauses; i++ ) {
1313        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1314
1315        // Generates: !__CFA_has_clause_run( clause_statuses[i] )
1316        statusExpr = new UntypedExpr ( cLoc,
1317            new NameExpr( cLoc, "!?" ),
1318            {
1319                new UntypedExpr ( cLoc, 
1320                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1321                    {
1322                        genArrAccessExpr( cLoc, i, statusArrName )
1323                    }
1324                )
1325            }
1326        );
1327       
1328        // Generates:
1329        // (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1330        statusExpr = new LogicalExpr( cLoc,
1331            new CastExpr( cLoc,
1332                statusExpr, 
1333                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1334            ),
1335            new CastExpr( cLoc,
1336                genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1337                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1338            ),
1339            LogicalFlag::AndExpr
1340        );
1341       
1342        // if when cond defined generates:
1343        // when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1344        if ( stmt->clauses.at(i)->when_cond )
1345            statusExpr = new LogicalExpr( cLoc,
1346                new CastExpr( cLoc,
1347                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1348                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1349                ),
1350                new CastExpr( cLoc,
1351                    statusExpr,
1352                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1353                ),
1354                LogicalFlag::AndExpr
1355            );
1356
1357        // generates:
1358        // if ( statusExpr ) { ... clausei stmt ... }
1359        unregisters->push_back( 
1360            new IfStmt( cLoc,
1361                statusExpr,
1362                new CompoundStmt( cLoc,
1363                    {
1364                        new IfStmt( cLoc,
1365                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "on_selected" ),
1366                            ast::deepCopy( stmt->clauses.at(i)->stmt )
1367                        )
1368                    }
1369                )
1370            )
1371        );
1372
1373        // // generates:
1374        // // if ( statusExpr ) { ... clausei stmt ... }
1375        // unregisters->push_back(
1376        //     new IfStmt( cLoc,
1377        //         statusExpr,
1378        //         genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1379        //     )
1380        // );
1381    }
1382
1383    body->push_back( 
1384        new ast::TryStmt(
1385            loc,
1386            tryBody,
1387            {},
1388            new ast::FinallyClause( loc, unregisters )
1389        )
1390    );
1391
1392    for ( ClauseData * datum : clauseData )
1393        delete datum;
1394
1395    return body;
1396}
1397
1398// To add the predicates at global scope we need to do it in a second pass
1399// Predicates are added after "struct select_node { ... };"
1400class AddPredicateDecls final : public WithDeclsToAdd<> {
1401    vector<FunctionDecl *> & satFns;
1402    const StructDecl * selectNodeDecl = nullptr;
1403
1404  public:
1405    void previsit( const StructDecl * decl ) {
1406        if ( !decl->body ) {
1407            return;
1408        } else if ( "select_node" == decl->name ) {
1409            assert( !selectNodeDecl );
1410            selectNodeDecl = decl;
1411            for ( FunctionDecl * fn : satFns )
1412                declsToAddAfter.push_back(fn);           
1413        }
1414    }
1415    AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1416};
1417
1418void generateWaitUntil( TranslationUnit & translationUnit ) {
1419    vector<FunctionDecl *> satFns;
1420        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1421    Pass<AddPredicateDecls>::run( translationUnit, satFns );
1422}
1423
1424} // namespace Concurrency
1425
1426// Local Variables: //
1427// tab-width: 4 //
1428// mode: c++ //
1429// compile-command: "make install" //
1430// End: //
Note: See TracBrowser for help on using the repository browser.