source: src/Concurrency/Waituntil.cpp @ 37ceccb

Last change on this file since 37ceccb was 1d66a91, checked in by caparsons <caparson@…>, 16 months ago

added support for general channel operators and cleaned up some cruft

  • Property mode set to 100644
File size: 57.9 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waituntil.hpp"
17
18#include <string>
19
20#include "AST/Copy.hpp"
21#include "AST/Expr.hpp"
22#include "AST/Pass.hpp"
23#include "AST/Print.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/Type.hpp"
26#include "Common/UniqueName.h"
27
28using namespace ast;
29using namespace std;
30
31/* So this is what this pass dones:
32{
33    when ( condA ) waituntil( A ){ doA(); }
34    or when ( condB ) waituntil( B ){ doB(); }
35    and when ( condC ) waituntil( C ) { doC(); }
36}
37                 ||
38                 ||
39                \||/
40                 \/
41
42Generates these two routines:
43static inline bool is_full_sat_1( int * clause_statuses ) {
44    return clause_statuses[0]
45        || clause_statuses[1]
46        && clause_statuses[2];
47}
48
49static inline bool is_done_sat_1( int * clause_statuses ) {
50    return has_run(clause_statuses[0])
51        || has_run(clause_statuses[1])
52        && has_run(clause_statuses[2]);
53}
54
55Replaces the waituntil statement above with the following code:
56{
57    // used with atomic_dec/inc to get binary semaphore behaviour
58    int park_counter = 0;
59
60    // status (one for each clause)
61    int clause_statuses[3] = { 0 };
62
63    bool whenA = condA;
64    bool whenB = condB;
65    bool whenC = condC;
66
67    if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
68    if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
69
70    // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
71
72    // three blocks
73    // for each block, create, setup, then register select_node
74    select_node clause1;
75    select_node clause2;
76    select_node clause3;
77
78    try {
79        if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
80        ... repeat ^ for B and C ...
81
82        // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
83
84        // loop & park until done
85        while( !is_full_sat_1( clause_statuses ) ) {
86           
87            // binary sem P();
88            if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
89                park();
90           
91            // execute any blocks available with status set to 0
92            for ( int i = 0; i < 3; i++ ) {
93                if (clause_statuses[i] == __SELECT_SAT) {
94                    switch (i) {
95                        case 0:
96                            try {
97                                    on_selected( A, clause1 );
98                                    doA();
99                            }
100                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
101                            break;
102                        case 1:
103                            ... same gen as A but for B and clause2 ...
104                            break;
105                        case 2:
106                            ... same gen as A but for C and clause3 ...
107                            break;
108                    }
109                }
110            }
111        }
112
113        // ensure that the blocks that triggered is_full_sat_1 are run
114        // by running every un-run block that is SAT from the start until
115        // the predicate is SAT when considering RUN status = true
116        for ( int i = 0; i < 3; i++ ) {
117            if (is_done_sat_1( clause_statuses )) break;
118            if (clause_statuses[i] == __SELECT_SAT)
119                ... Same if body here as in loop above ...
120        }
121    } finally {
122        // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
123        // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
124        if ( !has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) )
125            on_selected( A, clause1 )
126            doA();
127        ... repeat if above for B and C ...
128    }
129}
130
131*/
132
133namespace Concurrency {
134
135class GenerateWaitUntilCore final {
136    vector<FunctionDecl *> & satFns;
137        UniqueName namer_sat = "__is_full_sat_"s;
138    UniqueName namer_run = "__is_run_sat_"s;
139        UniqueName namer_park = "__park_counter_"s;
140        UniqueName namer_status = "__clause_statuses_"s;
141        UniqueName namer_node = "__clause_"s;
142    UniqueName namer_target = "__clause_target_"s;
143    UniqueName namer_when = "__when_cond_"s;
144    UniqueName namer_label = "__waituntil_label_"s;
145
146    string idxName = "__CFA_clause_idx_";
147
148    struct ClauseData {
149        string nodeName;
150        string targetName;
151        string whenName;
152        int index;
153        string & statusName;
154        ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
155    };
156
157    const StructDecl * selectNodeDecl = nullptr;
158
159    // This first set of routines are all used to do the complicated job of
160    //    dealing with how to set predicate statuses with certain when_conds T/F
161    //    so that the when_cond == F effectively makes that clause "disappear"
162    void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
163    void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
164    bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
165    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
166    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
167    void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
168    void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
169    void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
170    CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
171    Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
172
173    // These routines are just code-gen helpers
174    void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
175    void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
176    CompoundStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
177    Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
178    CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
179    Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
180    Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
181    void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
182    Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
183    Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
184    Stmt * genAllOr( const WaitUntilStmt * stmt );
185
186  public:
187    void previsit( const StructDecl * decl );
188        Stmt * postvisit( const WaitUntilStmt * stmt );
189    GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
190};
191
192// Finds select_node decl
193void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
194    if ( !decl->body ) {
195                return;
196        } else if ( "select_node" == decl->name ) {
197                assert( !selectNodeDecl );
198                selectNodeDecl = decl;
199        }
200}
201
202void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
203    // all children when-ambiguous
204    if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
205        // true iff an ancestor/descendant has a different operation
206        currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
207    // ambiguousWhen is initially false so theres no need to set it here
208}
209
210// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
211// This tree painting is needed to generate the if statements that set the initial state
212//    of the clause statuses when some clauses are turned off via when_cond
213// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
214// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
215// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
216void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
217    bool aBelow = false; // updated by child nodes
218    bool oBelow = false; // updated by child nodes
219    switch (currNode->op) {
220        case WaitUntilStmt::ClauseNode::AND:
221            paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
222            paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
223
224            // update currNode's when flag based on conditions listed in fn signature comment above
225            updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
226
227            // set return flags to tell parents which decendant ops have been seen
228            andBelow = true;
229            orBelow = oBelow;
230            return;
231        case WaitUntilStmt::ClauseNode::OR:
232            paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
233            paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
234
235            // update currNode's when flag based on conditions listed in fn signature comment above
236            updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
237
238            // set return flags to tell parents which decendant ops have been seen
239            andBelow = aBelow;
240            orBelow = true;
241            return;
242        case WaitUntilStmt::ClauseNode::LEAF:
243            if ( currNode->leaf->when_cond )
244                currNode->ambiguousWhen = true;
245            return;
246        default:
247            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
248    }
249}
250
251// overloaded wrapper for paintWhenTree that sets initial values
252// returns true if entire tree is OR's (special case)
253bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
254    bool aBelow = false, oBelow = false; // unused by initial call
255    paintWhenTree( currNode, false, false, aBelow, oBelow );
256    return !aBelow;
257}
258
259// Helper: returns Expr that represents arrName[index]
260Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
261    return new UntypedExpr ( loc, 
262        new NameExpr( loc, "?[?]" ),
263        {
264            new NameExpr( loc, arrName ),
265            ConstantExpr::from_int( loc, index )
266        }
267    );
268}
269
270// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
271// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
272// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
273// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
274void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
275    switch (currNode->op) {
276        case WaitUntilStmt::ClauseNode::AND:
277            collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
278            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
279            return;
280        case WaitUntilStmt::ClauseNode::OR:
281            collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
282            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
283            return;
284        case WaitUntilStmt::ClauseNode::LEAF:
285            if ( parentAmbig ) {
286                ambigIdxs.push_back(make_pair(index, currNode));
287            }
288            if ( parentAnd && currNode->leaf->when_cond ) {
289                currNode->childOfAnd = true;
290                andIdxs.push_back(index);
291            }
292            index++;
293            return;
294        default:
295            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
296    }
297}
298
299// overloaded wrapper for collectWhens that sets initial values
300void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
301    int idx = 0;
302    collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
303}
304
305// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
306//    subtrees are "turned off"
307// sets whenState = false iff both children have whenState == false.
308// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
309// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
310void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
311    if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
312    updateWhenState( currNode->left );
313    updateWhenState( currNode->right );
314    if ( !currNode->left->whenState && !currNode->right->whenState )
315        currNode->whenState = false;
316    else 
317        currNode->whenState = true;
318}
319
320// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
321// assumes that this will only be called on subtrees that are entirely whenState == false
322void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
323    if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
324        || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
325        // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
326        genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
327        genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
328    } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
329        || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
330        // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
331        CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
332        CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
333
334        // only one side needs to evaluate to status so we recurse on both subtrees
335        //    but only keep the statements from the subtree with minimal statements
336        genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
337        genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
338       
339        // append minimal statements to retStmt
340        if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
341            retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
342        } else {
343            retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
344        }
345       
346        delete leftStmt;
347        delete rightStmt;
348    } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
349        const CodeLocation & loc = stmt->location;
350        if ( status && !currNode->childOfAnd ) {
351            retStmt->push_back(
352                new ExprStmt( loc, 
353                    UntypedExpr::createAssign( loc,
354                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
355                        new NameExpr( loc, "__SELECT_RUN" )
356                    )
357                )
358            );
359        } else if ( !status && currNode->childOfAnd ) {
360            retStmt->push_back(
361                new ExprStmt( loc, 
362                    UntypedExpr::createAssign( loc,
363                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
364                        new NameExpr( loc, "__SELECT_UNSAT" )
365                    )
366                )
367            );
368        }
369
370        // No need to generate statements for the following cases since childOfAnd are always set to true
371        //    and !childOfAnd are always false
372        // - status && currNode->childOfAnd
373        // - !status && !currNode->childOfAnd
374        idx++;
375    }
376}
377
378void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
379    switch (currNode->op) {
380        case WaitUntilStmt::ClauseNode::AND:
381            // check which subtrees have all whenState == false (disabled)
382            if (!currNode->left->whenState && !currNode->right->whenState) {
383                // this case can only occur when whole tree is disabled since otherwise
384                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
385                assert( !currNode->whenState ); // paranoidWWW
386                // whole tree disabled so pass true so that select is SAT vacuously
387                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
388            } else if ( !currNode->left->whenState ) {
389                // pass true since x && true === x
390                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
391                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
392            } else if ( !currNode->right->whenState ) {
393                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
394                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
395            } else { 
396                // if no children with whenState == false recurse normally via break
397                break;
398            }
399            return;
400        case WaitUntilStmt::ClauseNode::OR:
401            if (!currNode->left->whenState && !currNode->right->whenState) {
402                assert( !currNode->whenState ); // paranoid
403                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
404            } else if ( !currNode->left->whenState ) {
405                // pass false since x || false === x
406                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
407                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
408            } else if ( !currNode->right->whenState ) {
409                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
410                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
411            } else { 
412                break;
413            }
414            return;
415        case WaitUntilStmt::ClauseNode::LEAF:
416            idx++;
417            return;
418        default:
419            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
420    }
421    genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
422    genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
423}
424
425// generates a minimal set of assignments for status arr based on which whens are toggled on/off
426CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
427    updateWhenState( stmt->predicateTree );
428    CompoundStmt * retval = new CompoundStmt( stmt->location );
429    int idx = 0;
430    genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
431    return retval;
432}
433
434// generates nested if/elses for all possible assignments of ambiguous when_conds
435// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
436Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, 
437    vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
438    // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
439    //    Why is size_type parameterized on the type stored in the vector?????
440
441    const CodeLocation & loc = stmt->location;
442    int clauseIdx = ambigClauses.at(ambigIdx).first;
443    WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
444    Stmt * thenStmt;
445    Stmt * elseStmt;
446   
447    if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
448        currNode->whenState = true;
449        thenStmt = getStatusAssignment( stmt, clauseData );
450        currNode->whenState = false;
451        elseStmt = getStatusAssignment( stmt, clauseData );
452    } else {
453        // recurse both with when enabled and disabled to generate all possible cases
454        currNode->whenState = true;
455        thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
456        currNode->whenState = false;
457        elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
458    }
459
460    // insert first recursion result in if ( __when_cond_ ) { ... }
461    // insert second recursion result in else { ... }
462    return new CompoundStmt ( loc,
463        {
464            new IfStmt( loc,
465                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
466                thenStmt,
467                elseStmt
468            )
469        }
470    );
471}
472
473// typedef a fn ptr so that we can reuse genPredExpr
474// genLeafExpr is used to refer to one of the following two routines
475typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
476
477// return Expr that represents clause_statuses[index]
478// mutates index to be index + 1
479Expr * genSatExpr( const CodeLocation & loc, int & index ) {
480    return genArrAccessExpr( loc, index++, "clause_statuses" );
481}
482
483// return Expr that represents has_run(clause_statuses[index])
484Expr * genRunExpr( const CodeLocation & loc, int & index ) {
485    return new UntypedExpr ( loc, 
486        new NameExpr( loc, "__CFA_has_clause_run" ),
487        { genSatExpr( loc, index ) }
488    );
489}
490
491// Takes in the ClauseNode tree and recursively generates
492// the predicate expr used inside the predicate functions
493Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
494    switch (currNode->op) {
495        case WaitUntilStmt::ClauseNode::AND:
496            return new LogicalExpr( loc, 
497                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
498                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
499                LogicalFlag::AndExpr
500            );
501        case WaitUntilStmt::ClauseNode::OR:
502            return new LogicalExpr( loc,
503                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
504                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
505                LogicalFlag::OrExpr );
506        case WaitUntilStmt::ClauseNode::LEAF:
507            return genLeaf( loc, idx );
508        default:
509            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
510    }
511}
512
513
514// Builds the predicate functions used to check the status of the waituntil statement
515/* Ex:
516{
517    waituntil( A ){ doA(); }
518    or waituntil( B ){ doB(); }
519    and waituntil( C ) { doC(); }
520}
521generates =>
522static inline bool is_full_sat_1( int * clause_statuses ) {
523    return clause_statuses[0]
524        || clause_statuses[1]
525        && clause_statuses[2];
526}
527
528static inline bool is_done_sat_1( int * clause_statuses ) {
529    return has_run(clause_statuses[0])
530        || has_run(clause_statuses[1])
531        && has_run(clause_statuses[2]);
532}
533*/
534// Returns a predicate function decl
535// predName and genLeaf determine if this generates an is_done or an is_full predicate
536FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
537    int arrIdx = 0;
538    const CodeLocation & loc = stmt->location;
539    CompoundStmt * body = new CompoundStmt( loc );
540    body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
541
542    return new FunctionDecl( loc,
543        predName,
544        {},                     // forall
545        {
546            new ObjectDecl( loc,
547                "clause_statuses",
548                new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) )
549            )
550        },
551        { 
552            new ObjectDecl( loc,
553                "sat_ret",
554                new BasicType( BasicType::Kind::Bool )
555            )
556        },
557        body,               // body
558        { Storage::Static },    // storage
559        Linkage::Cforall,       // linkage
560        {},                     // attributes
561        { Function::Inline }
562    );
563}
564
565// Creates is_done and is_full predicates
566void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
567    if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
568        satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) ); 
569    satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
570}
571
572// Adds the following to body:
573// if ( when_cond ) { // this if is omitted if no when() condition
574//      setup_clause( clause1, &clause_statuses[0], &park_counter );
575//      register_select(A, clause1);
576// }
577void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {   
578    CompoundStmt * currBody = body;
579    const CodeLocation & loc = clause->location;
580
581    // If we have a when_cond make the initialization conditional
582    if ( clause->when_cond )
583        currBody = new CompoundStmt( loc );
584
585    // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
586    currBody->push_back( new ExprStmt( loc,
587        new UntypedExpr ( loc,
588            new NameExpr( loc, "setup_clause" ),
589            {
590                new NameExpr( loc, data->nodeName ),
591                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
592                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
593            }
594        )
595    ));
596
597    // Generates: register_select(A, clause1);
598    currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
599
600    // generates: if ( when_cond ) { ... currBody ... }
601    if ( clause->when_cond )
602        body->push_back( 
603            new IfStmt( loc,
604                new NameExpr( loc, data->whenName ),
605                currBody
606            )
607        );
608}
609
610// Used to generate a call to one of the select trait routines
611Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
612    const CodeLocation & loc = clause->location;
613    return new UntypedExpr ( loc,
614        new NameExpr( loc, fnName ),
615        {
616            new NameExpr( loc, data->targetName ),
617            new NameExpr( loc, data->nodeName )
618        }
619    );
620}
621
622// Generates:
623/* on_selected( target_1, node_1 ); ... corresponding body of target_1 ...
624*/
625CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
626    const CodeLocation & cLoc = clause->location;
627    return new CompoundStmt( cLoc,
628        {
629            new IfStmt( cLoc,
630                genSelectTraitCall( clause, data, "on_selected" ),
631                ast::deepCopy( clause->stmt )
632            )
633        }
634    );
635}
636
637// this routine generates and returns the following
638/*for ( int i = 0; i < numClauses; i++ ) {
639    if ( predName(clause_statuses) ) break;
640    if (clause_statuses[i] == __SELECT_SAT) {
641        switch (i) {
642            case 0:
643                try {
644                    on_selected( target1, clause1 );
645                    dotarget1stmt();
646                }
647                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
648                break;
649            ...
650            case N:
651                ...
652                break;
653        }
654    }
655}*/
656CompoundStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
657    CompoundStmt * ifBody = new CompoundStmt( stmt->location );
658    const CodeLocation & loc = stmt->location;
659
660    string switchLabel = namer_label.newName();
661
662    /* generates:
663    switch (i) {
664        case 0:
665            try {
666                on_selected( target1, clause1 );
667                dotarget1stmt();
668            }
669            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
670            break;
671            ...
672        case N:
673            ...
674            break;
675    }*/
676    std::vector<ptr<CaseClause>> switchCases;
677    int idx = 0;
678    for ( const auto & clause: stmt->clauses ) {
679        const CodeLocation & cLoc = clause->location;
680        switchCases.push_back(
681            new CaseClause( cLoc,
682                ConstantExpr::from_int( cLoc, idx ),
683                {
684                    new CompoundStmt( cLoc,
685                        {
686                            new ast::TryStmt( cLoc,
687                                genStmtBlock( clause, clauseData.at(idx) ),
688                                {},
689                                new ast::FinallyClause( cLoc, 
690                                    new CompoundStmt( cLoc,
691                                        {
692                                            new ExprStmt( loc,
693                                                new UntypedExpr ( loc,
694                                                    new NameExpr( loc, "?=?" ),
695                                                    {
696                                                        new UntypedExpr ( loc, 
697                                                            new NameExpr( loc, "?[?]" ),
698                                                            {
699                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
700                                                                new NameExpr( loc, idxName )
701                                                            }
702                                                        ),
703                                                        new NameExpr( loc, "__SELECT_RUN" )
704                                                    }
705                                                )
706                                            ),
707                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
708                                        }
709                                    )
710                                )
711                            ),
712                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc, switchLabel ) )
713                        }
714                    )
715                }
716            )
717        );
718        idx++;
719    }
720
721    ifBody->push_back(
722        new SwitchStmt( loc,
723            new NameExpr( loc, idxName ),
724            std::move( switchCases ),
725            { Label( loc, switchLabel ) }
726        )
727    );
728
729    // gens:
730    // if (clause_statuses[i] == __SELECT_SAT) {
731    //      ... ifBody  ...
732    // }
733    IfStmt * ifSwitch = new IfStmt( loc,
734        new UntypedExpr ( loc,
735            new NameExpr( loc, "?==?" ),
736            {
737                new UntypedExpr ( loc, 
738                    new NameExpr( loc, "?[?]" ),
739                    {
740                        new NameExpr( loc, clauseData.at(0)->statusName ),
741                        new NameExpr( loc, idxName )
742                    }
743                ),
744                new NameExpr( loc, "__SELECT_SAT" )
745            }
746        ),      // condition
747        ifBody  // body
748    );
749
750    string forLabel = namer_label.newName();
751
752    // we hoist init here so that this pass can happen after hoistdecls pass
753    return new CompoundStmt( loc,
754        {
755            new DeclStmt( loc,
756                new ObjectDecl( loc,
757                    idxName,
758                    new BasicType( BasicType::Kind::SignedInt ),
759                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
760                )
761            ),
762            new ForStmt( loc,
763                {},  // inits
764                new UntypedExpr ( loc,
765                    new NameExpr( loc, "?<?" ),
766                    {
767                        new NameExpr( loc, idxName ),
768                        ConstantExpr::from_int( loc, stmt->clauses.size() )
769                    }
770                ),  // cond
771                new UntypedExpr ( loc,
772                    new NameExpr( loc, "?++" ),
773                    { new NameExpr( loc, idxName ) }
774                ),  // inc
775                new CompoundStmt( loc,
776                    {
777                        new IfStmt( loc,
778                            new UntypedExpr ( loc,
779                                new NameExpr( loc, predName ),
780                                { new NameExpr( loc, clauseData.at(0)->statusName ) }
781                            ),
782                            new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc, forLabel ) )
783                        ),
784                        ifSwitch
785                    }
786                ),   // body
787                { Label( loc, forLabel ) }
788            )
789        }
790    );
791}
792
793// Generates: !is_full_sat_n() / !is_run_sat_n()
794Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
795    const CodeLocation & loc = stmt->location;
796    return new UntypedExpr ( loc,
797        new NameExpr( loc, "!?" ),
798        {
799            new UntypedExpr ( loc,
800                new NameExpr( loc, satName ),
801                { new NameExpr( loc, arrName ) }
802            )
803        }
804    );
805}
806
807// Generates the code needed for waituntils with an else ( ... )
808// Checks clauses once after registering for completion and runs them if completes
809// If not enough have run to satisfy predicate after one pass then the else is run
810Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
811    return new CompoundStmt( stmt->else_stmt->location,
812        {
813            genStatusCheckFor( stmt, clauseData, runName ),
814            new IfStmt( stmt->else_stmt->location,
815                genNotSatExpr( stmt, runName, arrName ),
816                ast::deepCopy( stmt->else_stmt )
817            )
818        }
819    );
820}
821
822Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
823    CompoundStmt * whileBody = new CompoundStmt( stmt->location );
824    const CodeLocation & loc = stmt->location;
825
826    // generates: __CFA_maybe_park( &park_counter );
827    whileBody->push_back(
828        new ExprStmt( loc,
829            new UntypedExpr ( loc,
830                new NameExpr( loc, "__CFA_maybe_park" ),
831                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
832            )
833        )
834    );
835
836    whileBody->push_back( genStatusCheckFor( stmt, clauseData, runName ) );
837
838    return new CompoundStmt( loc,
839        {
840            new WhileDoStmt( loc,
841                genNotSatExpr( stmt, runName, arrName ),
842                whileBody,  // body
843                {}          // no inits
844            )
845        }
846    );
847}
848
849// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
850// typeof(target) & __clause_target_0 = target;
851// bool __when_cond_0 = when_cond; // only generated if when_cond defined
852// select_node clause1;
853void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
854    ClauseData * currClause;
855    for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
856        currClause = new ClauseData( i, statusName );
857        currClause->nodeName = namer_node.newName();
858        currClause->targetName = namer_target.newName();
859        currClause->whenName = namer_when.newName();
860        clauseData.push_back(currClause);
861        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
862
863        // typeof(target) & __clause_target_0 = target;
864        body->push_back(
865            new DeclStmt( cLoc,
866                new ObjectDecl( cLoc,
867                    currClause->targetName,
868                    new ReferenceType( 
869                        new TypeofType( new UntypedExpr( cLoc,
870                            new NameExpr( cLoc, "__CFA_select_get_type" ),
871                            { ast::deepCopy( stmt->clauses.at(i)->target ) }
872                        ))
873                    ),
874                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
875                )
876            )
877        );
878
879        // bool __when_cond_0 = when_cond; // only generated if when_cond defined
880        if ( stmt->clauses.at(i)->when_cond )
881            body->push_back(
882                new DeclStmt( cLoc,
883                    new ObjectDecl( cLoc,
884                        currClause->whenName,
885                        new BasicType( BasicType::Kind::Bool ),
886                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
887                    )
888                )
889            );
890       
891        // select_node clause1;
892        body->push_back(
893            new DeclStmt( cLoc,
894                new ObjectDecl( cLoc,
895                    currClause->nodeName,
896                    new StructInstType( selectNodeDecl )
897                )
898            )
899        );
900    }
901
902    if ( stmt->else_stmt && stmt->else_cond ) {
903        body->push_back(
904            new DeclStmt( stmt->else_cond->location,
905                new ObjectDecl( stmt->else_cond->location,
906                    elseWhenName,
907                    new BasicType( BasicType::Kind::Bool ),
908                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
909                )
910            )
911        );
912    }
913}
914
915/*
916if ( clause_status == &clause1 ) ... clause 1 body ...
917...
918elif ( clause_status == &clausen ) ... clause n body ...
919*/
920Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
921    const CodeLocation & loc = stmt->location;
922
923    IfStmt * outerIf = nullptr;
924        IfStmt * lastIf = nullptr;
925
926        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
927        for ( long unsigned int i = 0; i < data.size(); i++ ) {
928        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
929
930                IfStmt * currIf = new IfStmt( cLoc,
931                        new UntypedExpr( cLoc, 
932                new NameExpr( cLoc, "?==?" ), 
933                {
934                    new NameExpr( cLoc, statusName ),
935                    new CastExpr( cLoc, 
936                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
937                        new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
938                    )
939                }
940            ),
941            genStmtBlock( stmt->clauses.at(i), data.at(i) )
942                );
943               
944                if ( i == 0 ) {
945                        outerIf = currIf;
946                } else {
947                        // add ifstmt to else of previous stmt
948                        lastIf->else_ = currIf;
949                }
950
951                lastIf = currIf;
952        }
953
954    return new CompoundStmt( loc,
955        {
956            new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
957            outerIf
958        }
959    );
960}
961
962Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
963    if ( idx == data.size() ) {   // base case, gen last else
964        const CodeLocation & cLoc = stmt->else_stmt->location;
965        if ( !stmt->else_stmt ) // normal non-else gen
966            return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
967
968        Expr * raceFnCall = new UntypedExpr( stmt->location,
969            new NameExpr( stmt->location, "__select_node_else_race" ),
970            { new NameExpr( stmt->location, data.at(0)->nodeName ) }
971        );
972
973        if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
974            return new IfStmt( cLoc,
975                new LogicalExpr( cLoc,
976                    new CastExpr( cLoc,
977                        new NameExpr( cLoc, elseWhenName ),
978                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
979                    ),
980                    new CastExpr( cLoc,
981                        raceFnCall,
982                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
983                    ),
984                    LogicalFlag::AndExpr
985                ),
986                ast::deepCopy( stmt->else_stmt ),
987                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
988            );
989        }
990
991        // return else conditional on race
992        return new IfStmt( stmt->else_stmt->location,
993            raceFnCall,
994            ast::deepCopy( stmt->else_stmt ),
995            buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
996        );
997    }
998    const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
999
1000    Expr * baseCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
1001    Expr * ifCond;
1002
1003    // If we have a when_cond make the register call conditional on it
1004    if ( stmt->clauses.at(idx)->when_cond ) {
1005        ifCond = new LogicalExpr( cLoc,
1006            new CastExpr( cLoc,
1007                new NameExpr( cLoc, data.at(idx)->whenName ), 
1008                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1009            ),
1010            new CastExpr( cLoc,
1011                baseCond,
1012                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1013            ),
1014            LogicalFlag::AndExpr
1015        );
1016    } else ifCond = baseCond;
1017
1018    return new CompoundStmt( cLoc,
1019        {   // gens: setup_clause( clause1, &status, 0p );
1020            new ExprStmt( cLoc,
1021                new UntypedExpr ( cLoc,
1022                    new NameExpr( cLoc, "setup_clause" ),
1023                    {
1024                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1025                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1026                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::SignedInt ) ) )
1027                    }
1028                )
1029            ),
1030            // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1031            new IfStmt( cLoc,
1032                ifCond,
1033                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1034                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1035            )
1036        }
1037    );
1038}
1039
1040// This gens the special case of an all OR waituntil:
1041/*
1042int status = 0;
1043
1044typeof(target) & __clause_target_0 = target;
1045bool __when_cond_0 = when_cond; // only generated if when_cond defined
1046select_node clause1;
1047... generate above for rest of clauses ...
1048
1049try {
1050    setup_clause( clause1, &status, 0p );
1051    if ( __when_cond_0 && register_select( 1 ) ) {
1052        ... clause 1 body ...
1053    } else {
1054        ... recursively gen for each of n clauses ...
1055        setup_clause( clausen, &status, 0p );
1056        if ( __when_cond_n-1 && register_select( n ) ) {
1057            ... clause n body ...
1058        } else {
1059            if ( else_when ) ... else clause body ...
1060            else {
1061                park();
1062
1063                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1064                if ( clause_status == &clause1) ... clause 1 body ...
1065                ...
1066                elif ( clause_status == &clausen ) ... clause n body ...
1067            }
1068        }
1069    }
1070}
1071finally {
1072    if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1073    ...
1074    if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1075}
1076*/
1077Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1078    const CodeLocation & loc = stmt->location;
1079    string statusName = namer_status.newName();
1080    string elseWhenName = namer_when.newName();
1081    int numClauses = stmt->clauses.size();
1082    CompoundStmt * body = new CompoundStmt( stmt->location );
1083
1084    // Generates: unsigned long int status = 0;
1085    body->push_back( new DeclStmt( loc,
1086        new ObjectDecl( loc,
1087            statusName,
1088            new BasicType( BasicType::Kind::LongUnsignedInt ),
1089            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1090        )
1091    ));
1092
1093    vector<ClauseData *> clauseData;
1094    genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1095
1096    vector<int> whenIndices; // track which clauses have whens
1097
1098    CompoundStmt * unregisters = new CompoundStmt( loc );
1099    Expr * ifCond;
1100    for ( int i = 0; i < numClauses; i++ ) {
1101        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1102        // Gens: node.status != 0p
1103        UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc, 
1104            new NameExpr( cLoc, "?!=?" ), 
1105            {
1106                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) ) ),
1107                new UntypedExpr( cLoc, 
1108                    new NameExpr( cLoc, "__get_clause_status" ), 
1109                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) } 
1110                ) 
1111            }
1112        );
1113
1114        // If we have a when_cond make the unregister call conditional on it
1115        if ( stmt->clauses.at(i)->when_cond ) {
1116            whenIndices.push_back(i);
1117            ifCond = new LogicalExpr( cLoc,
1118                new CastExpr( cLoc,
1119                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1120                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1121                ),
1122                new CastExpr( cLoc,
1123                    statusPtrCheck,
1124                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1125                ),
1126                LogicalFlag::AndExpr
1127            );
1128        } else ifCond = statusPtrCheck;
1129       
1130        unregisters->push_back(
1131            new IfStmt( cLoc,
1132                ifCond,
1133                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) ) 
1134            )
1135        );
1136    }
1137
1138    if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1139        body->push_back(
1140                new ast::TryStmt( loc,
1141                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1142                {},
1143                new ast::FinallyClause( loc, unregisters )
1144            )
1145        );
1146    } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1147        Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1148        Expr * lastExpr = outerIfCond;
1149
1150        for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1151            outerIfCond = new LogicalExpr( loc,
1152                new CastExpr( loc,
1153                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ), 
1154                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1155                ),
1156                new CastExpr( loc,
1157                    lastExpr,
1158                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1159                ),
1160                LogicalFlag::OrExpr
1161            );
1162            lastExpr = outerIfCond;
1163        }
1164
1165        body->push_back(
1166                new ast::TryStmt( loc,
1167                new CompoundStmt( loc, 
1168                    {
1169                        new IfStmt( loc,
1170                            outerIfCond,
1171                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1172                        )
1173                    }
1174                ),
1175                {},
1176                new ast::FinallyClause( loc, unregisters )
1177            )
1178        );
1179    }
1180
1181    for ( ClauseData * datum : clauseData )
1182        delete datum;
1183
1184    return body;
1185}
1186
1187Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1188    if ( !selectNodeDecl )
1189        SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1190
1191    // Prep clause tree to figure out how to set initial statuses
1192    // setTreeSizes( stmt->predicateTree );
1193    if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1194        return genAllOr( stmt );
1195
1196    CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1197    CompoundStmt * body = new CompoundStmt( stmt->location );
1198    string statusArrName = namer_status.newName();
1199    string pCountName = namer_park.newName();
1200    string satName = namer_sat.newName();
1201    string runName = namer_run.newName();
1202    string elseWhenName = namer_when.newName();
1203    int numClauses = stmt->clauses.size();
1204    addPredicates( stmt, satName, runName );
1205
1206    const CodeLocation & loc = stmt->location;
1207
1208    // Generates: int park_counter = 0;
1209    body->push_back( new DeclStmt( loc,
1210        new ObjectDecl( loc,
1211            pCountName,
1212            new BasicType( BasicType::Kind::SignedInt ),
1213            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1214        )
1215    ));
1216
1217    // Generates: int clause_statuses[3] = { 0 };
1218    body->push_back( new DeclStmt( loc,
1219        new ObjectDecl( loc,
1220            statusArrName,
1221            new ArrayType( new BasicType( BasicType::Kind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1222            new ListInit( loc,
1223                {
1224                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1225                }
1226            )
1227        )
1228    ));
1229
1230    vector<ClauseData *> clauseData;
1231    genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1232
1233    vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1234    vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1235
1236    collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1237
1238    // This is only needed for clauses that have AND as a parent and a when_cond defined
1239    // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1240    for ( int idx : andWhenClauses ) {
1241        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1242        body->push_back( 
1243            new IfStmt( cLoc,
1244                new UntypedExpr ( cLoc,
1245                    new NameExpr( cLoc, "!?" ),
1246                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1247                ),  // IfStmt cond
1248                new ExprStmt( cLoc,
1249                    new UntypedExpr ( cLoc,
1250                        new NameExpr( cLoc, "?=?" ),
1251                        {
1252                            new UntypedExpr ( cLoc, 
1253                                new NameExpr( cLoc, "?[?]" ),
1254                                {
1255                                    new NameExpr( cLoc, statusArrName ),
1256                                    ConstantExpr::from_int( cLoc, idx )
1257                                }
1258                            ),
1259                            new NameExpr( cLoc, "__SELECT_RUN" )
1260                        }
1261                    )
1262                )  // IfStmt then
1263            )
1264        );
1265    }
1266
1267    // Only need to generate conditional initial state setting for ambiguous when clauses
1268    if ( !ambiguousClauses.empty() ) {
1269        body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1270    }
1271
1272    // generates the following for each clause:
1273    // setup_clause( clause1, &clause_statuses[0], &park_counter );
1274    // register_select(A, clause1);
1275    for ( int i = 0; i < numClauses; i++ ) {
1276        setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1277    }
1278
1279    // generate satisfy logic based on if there is an else clause and if it is conditional
1280    if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1281        tryBody->push_back(
1282            new IfStmt( stmt->else_cond->location,
1283                new NameExpr( stmt->else_cond->location, elseWhenName ),
1284                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1285                genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData )
1286            )
1287        );
1288    } else if ( !stmt->else_stmt ) { // normal gen
1289        tryBody->push_back( genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData ) );
1290    } else { // generate just else
1291        tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1292    }
1293
1294    // Collection of unregister calls on resources to be put in finally clause
1295    // for each clause:
1296    // if ( !__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1297    // OR if when( ... ) defined on resource
1298    // if ( when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1299    CompoundStmt * unregisters = new CompoundStmt( loc );
1300
1301    Expr * statusExpr; // !__CFA_has_clause_run( clause_statuses[i] )
1302    for ( int i = 0; i < numClauses; i++ ) {
1303        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1304
1305        // Generates: !__CFA_has_clause_run( clause_statuses[i] )
1306        statusExpr = new UntypedExpr ( cLoc,
1307            new NameExpr( cLoc, "!?" ),
1308            {
1309                new UntypedExpr ( cLoc, 
1310                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1311                    {
1312                        genArrAccessExpr( cLoc, i, statusArrName )
1313                    }
1314                )
1315            }
1316        );
1317       
1318        // Generates:
1319        // (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1320        statusExpr = new LogicalExpr( cLoc,
1321            new CastExpr( cLoc,
1322                statusExpr, 
1323                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1324            ),
1325            new CastExpr( cLoc,
1326                genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1327                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1328            ),
1329            LogicalFlag::AndExpr
1330        );
1331       
1332        // if when cond defined generates:
1333        // when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1334        if ( stmt->clauses.at(i)->when_cond )
1335            statusExpr = new LogicalExpr( cLoc,
1336                new CastExpr( cLoc,
1337                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1338                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1339                ),
1340                new CastExpr( cLoc,
1341                    statusExpr,
1342                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1343                ),
1344                LogicalFlag::AndExpr
1345            );
1346
1347        // generates:
1348        // if ( statusExpr ) { ... clausei stmt ... }
1349        unregisters->push_back( 
1350            new IfStmt( cLoc,
1351                statusExpr,
1352                new CompoundStmt( cLoc,
1353                    {
1354                        new IfStmt( cLoc,
1355                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "on_selected" ),
1356                            ast::deepCopy( stmt->clauses.at(i)->stmt )
1357                        )
1358                    }
1359                )
1360            )
1361        );
1362
1363        // // generates:
1364        // // if ( statusExpr ) { ... clausei stmt ... }
1365        // unregisters->push_back(
1366        //     new IfStmt( cLoc,
1367        //         statusExpr,
1368        //         genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1369        //     )
1370        // );
1371    }
1372
1373    body->push_back( 
1374        new ast::TryStmt(
1375            loc,
1376            tryBody,
1377            {},
1378            new ast::FinallyClause( loc, unregisters )
1379        )
1380    );
1381
1382    for ( ClauseData * datum : clauseData )
1383        delete datum;
1384
1385    return body;
1386}
1387
1388// To add the predicates at global scope we need to do it in a second pass
1389// Predicates are added after "struct select_node { ... };"
1390class AddPredicateDecls final : public WithDeclsToAdd<> {
1391    vector<FunctionDecl *> & satFns;
1392    const StructDecl * selectNodeDecl = nullptr;
1393
1394  public:
1395    void previsit( const StructDecl * decl ) {
1396        if ( !decl->body ) {
1397            return;
1398        } else if ( "select_node" == decl->name ) {
1399            assert( !selectNodeDecl );
1400            selectNodeDecl = decl;
1401            for ( FunctionDecl * fn : satFns )
1402                declsToAddAfter.push_back(fn);           
1403        }
1404    }
1405    AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1406};
1407
1408void generateWaitUntil( TranslationUnit & translationUnit ) {
1409    vector<FunctionDecl *> satFns;
1410        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1411    Pass<AddPredicateDecls>::run( translationUnit, satFns );
1412}
1413
1414} // namespace Concurrency
1415
1416// Local Variables: //
1417// tab-width: 4 //
1418// mode: c++ //
1419// compile-command: "make install" //
1420// End: //
Note: See TracBrowser for help on using the repository browser.