source: src/Concurrency/Waituntil.cpp @ a33a5e2

ADTast-experimental
Last change on this file since a33a5e2 was c86b08d, checked in by caparsons <caparson@…>, 14 months ago

added support for the waituntil statement in the compiler

  • Property mode set to 100644
File size: 57.6 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include <string>
17
18#include "Waituntil.hpp"
19#include "AST/Expr.hpp"
20#include "AST/Pass.hpp"
21#include "AST/Print.hpp"
22#include "AST/Stmt.hpp"
23#include "AST/Type.hpp"
24#include "Common/UniqueName.h"
25
26using namespace ast;
27using namespace std;
28
29/* So this is what this pass dones:
30{
31    when ( condA ) waituntil( A ){ doA(); }
32    or when ( condB ) waituntil( B ){ doB(); }
33    and when ( condC ) waituntil( C ) { doC(); }
34}
35                 ||
36                 ||
37                \||/
38                 \/
39
40Generates these two routines:
41static inline bool is_full_sat_1( int * clause_statuses ) {
42    return clause_statuses[0]
43        || clause_statuses[1]
44        && clause_statuses[2];
45}
46
47static inline bool is_done_sat_1( int * clause_statuses ) {
48    return has_run(clause_statuses[0])
49        || has_run(clause_statuses[1])
50        && has_run(clause_statuses[2]);
51}
52
53Replaces the waituntil statement above with the following code:
54{
55    // used with atomic_dec/inc to get binary semaphore behaviour
56    int park_counter = 0;
57
58    // status (one for each clause)
59    int clause_statuses[3] = { 0 };
60
61    bool whenA = condA;
62    bool whenB = condB;
63    bool whenC = condC;
64
65    if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
66    if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
67
68    // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
69
70    // three blocks
71    // for each block, create, setup, then register select_node
72    select_node clause1;
73    select_node clause2;
74    select_node clause3;
75
76    try {
77        if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
78        ... repeat ^ for B and C ...
79
80        // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
81
82        // loop & park until done
83        while( !is_full_sat_1( clause_statuses ) ) {
84           
85            // binary sem P();
86            if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
87                park();
88           
89            // execute any blocks available with status set to 0
90            for ( int i = 0; i < 3; i++ ) {
91                if (clause_statuses[i] == __SELECT_SAT) {
92                    switch (i) {
93                        case 0:
94                            try {
95                                if (on_selected( A, clause1 ))
96                                    doA();
97                            }
98                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
99                            break;
100                        case 1:
101                            ... same gen as A but for B and clause2 ...
102                            break;
103                        case 2:
104                            ... same gen as A but for C and clause3 ...
105                            break;
106                    }
107                }
108            }
109        }
110
111        // ensure that the blocks that triggered is_full_sat_1 are run
112        // by running every un-run block that is SAT from the start until
113        // the predicate is SAT when considering RUN status = true
114        for ( int i = 0; i < 3; i++ ) {
115            if (is_done_sat_1( clause_statuses )) break;
116            if (clause_statuses[i] == __SELECT_SAT)
117                ... Same if body here as in loop above ...
118        }
119    } finally {
120        // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
121        // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
122        if ( ! has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) && on_selected( A, clause1 ) )
123            doA();
124        ... repeat if above for B and C ...
125    }
126}
127
128*/
129
130namespace Concurrency {
131
132class GenerateWaitUntilCore final {
133    vector<FunctionDecl *> & satFns;
134        UniqueName namer_sat = "__is_full_sat_"s;
135    UniqueName namer_run = "__is_run_sat_"s;
136        UniqueName namer_park = "__park_counter_"s;
137        UniqueName namer_status = "__clause_statuses_"s;
138        UniqueName namer_node = "__clause_"s;
139    UniqueName namer_target = "__clause_target_"s;
140    UniqueName namer_when = "__when_cond_"s;
141
142    string idxName = "__CFA_clause_idx_";
143
144    struct ClauseData {
145        string nodeName;
146        string targetName;
147        string whenName;
148        int index;
149        string & statusName;
150        ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
151    };
152
153    const StructDecl * selectNodeDecl = nullptr;
154
155    // This first set of routines are all used to do the complicated job of
156    //    dealing with how to set predicate statuses with certain when_conds T/F
157    //    so that the when_cond == F effectively makes that clause "disappear"
158    void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
159    void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
160    bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
161    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
162    void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
163    void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
164    void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
165    void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
166    CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
167    Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
168
169    // These routines are just code-gen helpers
170    void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
171    void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
172    ForStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
173    Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
174    CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
175    Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
176    Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
177    void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
178    Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
179    Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
180    Stmt * genAllOr( const WaitUntilStmt * stmt );
181
182  public:
183    void previsit( const StructDecl * decl );
184        Stmt * postvisit( const WaitUntilStmt * stmt );
185    GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
186};
187
188// Finds select_node decl
189void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
190    if ( !decl->body ) {
191                return;
192        } else if ( "select_node" == decl->name ) {
193                assert( !selectNodeDecl );
194                selectNodeDecl = decl;
195        }
196}
197
198void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
199    // all children when-ambiguous
200    if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
201        // true iff an ancestor/descendant has a different operation
202        currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
203    // ambiguousWhen is initially false so theres no need to set it here
204}
205
206// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
207// This tree painting is needed to generate the if statements that set the initial state
208//    of the clause statuses when some clauses are turned off via when_cond
209// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
210// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
211// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
212void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
213    bool aBelow = false; // updated by child nodes
214    bool oBelow = false; // updated by child nodes
215    switch (currNode->op) {
216        case WaitUntilStmt::ClauseNode::AND:
217            paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
218            paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
219
220            // update currNode's when flag based on conditions listed in fn signature comment above
221            updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
222
223            // set return flags to tell parents which decendant ops have been seen
224            andBelow = true;
225            orBelow = oBelow;
226            return;
227        case WaitUntilStmt::ClauseNode::OR:
228            paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
229            paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
230
231            // update currNode's when flag based on conditions listed in fn signature comment above
232            updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
233
234            // set return flags to tell parents which decendant ops have been seen
235            andBelow = aBelow;
236            orBelow = true;
237            return;
238        case WaitUntilStmt::ClauseNode::LEAF:
239            if ( currNode->leaf->when_cond )
240                currNode->ambiguousWhen = true;
241            return;
242        default:
243            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
244    }
245}
246
247// overloaded wrapper for paintWhenTree that sets initial values
248// returns true if entire tree is OR's (special case)
249bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
250    bool aBelow = false, oBelow = false; // unused by initial call
251    paintWhenTree( currNode, false, false, aBelow, oBelow );
252    return !aBelow;
253}
254
255// Helper: returns Expr that represents arrName[index]
256Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
257    return new UntypedExpr ( loc, 
258        new NameExpr( loc, "?[?]" ),
259        {
260            new NameExpr( loc, arrName ),
261            ConstantExpr::from_int( loc, index )
262        }
263    );
264}
265
266// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
267// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
268// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
269// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
270void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
271    switch (currNode->op) {
272        case WaitUntilStmt::ClauseNode::AND:
273            collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
274            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
275            return;
276        case WaitUntilStmt::ClauseNode::OR:
277            collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
278            collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
279            return;
280        case WaitUntilStmt::ClauseNode::LEAF:
281            if ( parentAmbig ) {
282                ambigIdxs.push_back(make_pair(index, currNode));
283            }
284            if ( parentAnd && currNode->leaf->when_cond ) {
285                currNode->childOfAnd = true;
286                andIdxs.push_back(index);
287            }
288            index++;
289            return;
290        default:
291            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
292    }
293}
294
295// overloaded wrapper for collectWhens that sets initial values
296void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
297    int idx = 0;
298    collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
299}
300
301// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
302//    subtrees are "turned off"
303// sets whenState = false iff both children have whenState == false.
304// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
305// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
306void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
307    if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
308    updateWhenState( currNode->left );
309    updateWhenState( currNode->right );
310    if ( !currNode->left->whenState && !currNode->right->whenState )
311        currNode->whenState = false;
312    else 
313        currNode->whenState = true;
314}
315
316// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
317// assumes that this will only be called on subtrees that are entirely whenState == false
318void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
319    if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
320        || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
321        // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
322        genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
323        genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
324    } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
325        || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
326        // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
327        CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
328        CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
329
330        // only one side needs to evaluate to status so we recurse on both subtrees
331        //    but only keep the statements from the subtree with minimal statements
332        genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
333        genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
334       
335        // append minimal statements to retStmt
336        if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
337            retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
338        } else {
339            retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
340        }
341       
342        delete leftStmt;
343        delete rightStmt;
344    } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
345        const CodeLocation & loc = stmt->location;
346        if ( status && !currNode->childOfAnd ) {
347            retStmt->push_back(
348                new ExprStmt( loc, 
349                    UntypedExpr::createAssign( loc,
350                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
351                        new NameExpr( loc, "__SELECT_RUN" )
352                    )
353                )
354            );
355        } else if ( !status && currNode->childOfAnd ) {
356            retStmt->push_back(
357                new ExprStmt( loc, 
358                    UntypedExpr::createAssign( loc,
359                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
360                        new NameExpr( loc, "__SELECT_UNSAT" )
361                    )
362                )
363            );
364        }
365
366        // No need to generate statements for the following cases since childOfAnd are always set to true
367        //    and !childOfAnd are always false
368        // - status && currNode->childOfAnd
369        // - !status && !currNode->childOfAnd
370        idx++;
371    }
372}
373
374void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
375    switch (currNode->op) {
376        case WaitUntilStmt::ClauseNode::AND:
377            // check which subtrees have all whenState == false (disabled)
378            if (!currNode->left->whenState && !currNode->right->whenState) {
379                // this case can only occur when whole tree is disabled since otherwise
380                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
381                assert( !currNode->whenState ); // paranoidWWW
382                // whole tree disabled so pass true so that select is SAT vacuously
383                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
384            } else if ( !currNode->left->whenState ) {
385                // pass true since x && true === x
386                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
387                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
388            } else if ( !currNode->right->whenState ) {
389                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
390                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
391            } else { 
392                // if no children with whenState == false recurse normally via break
393                break;
394            }
395            return;
396        case WaitUntilStmt::ClauseNode::OR:
397            if (!currNode->left->whenState && !currNode->right->whenState) {
398                assert( !currNode->whenState ); // paranoid
399                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
400            } else if ( !currNode->left->whenState ) {
401                // pass false since x || false === x
402                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
403                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
404            } else if ( !currNode->right->whenState ) {
405                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
406                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
407            } else { 
408                break;
409            }
410            return;
411        case WaitUntilStmt::ClauseNode::LEAF:
412            idx++;
413            return;
414        default:
415            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
416    }
417    genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
418    genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
419}
420
421// generates a minimal set of assignments for status arr based on which whens are toggled on/off
422CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
423    updateWhenState( stmt->predicateTree );
424    CompoundStmt * retval = new CompoundStmt( stmt->location );
425    int idx = 0;
426    genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
427    return retval;
428}
429
430// generates nested if/elses for all possible assignments of ambiguous when_conds
431// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
432Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, 
433    vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
434    // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
435    //    Why is size_type parameterized on the type stored in the vector?????
436
437    const CodeLocation & loc = stmt->location;
438    int clauseIdx = ambigClauses.at(ambigIdx).first;
439    WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
440    Stmt * thenStmt;
441    Stmt * elseStmt;
442   
443    if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
444        currNode->whenState = true;
445        thenStmt = getStatusAssignment( stmt, clauseData );
446        currNode->whenState = false;
447        elseStmt = getStatusAssignment( stmt, clauseData );
448    } else {
449        // recurse both with when enabled and disabled to generate all possible cases
450        currNode->whenState = true;
451        thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
452        currNode->whenState = false;
453        elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
454    }
455
456    // insert first recursion result in if ( __when_cond_ ) { ... }
457    // insert second recursion result in else { ... }
458    return new CompoundStmt ( loc,
459        {
460            new IfStmt( loc,
461                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
462                thenStmt,
463                elseStmt
464            )
465        }
466    );
467}
468
469// typedef a fn ptr so that we can reuse genPredExpr
470// genLeafExpr is used to refer to one of the following two routines
471typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
472
473// return Expr that represents clause_statuses[index]
474// mutates index to be index + 1
475Expr * genSatExpr( const CodeLocation & loc, int & index ) {
476    return genArrAccessExpr( loc, index++, "clause_statuses" );
477}
478
479// return Expr that represents has_run(clause_statuses[index])
480Expr * genRunExpr( const CodeLocation & loc, int & index ) {
481    return new UntypedExpr ( loc, 
482        new NameExpr( loc, "__CFA_has_clause_run" ),
483        { genSatExpr( loc, index ) }
484    );
485}
486
487// Takes in the ClauseNode tree and recursively generates
488// the predicate expr used inside the predicate functions
489Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
490    switch (currNode->op) {
491        case WaitUntilStmt::ClauseNode::AND:
492            return new LogicalExpr( loc, 
493                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
494                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
495                LogicalFlag::AndExpr
496            );
497        case WaitUntilStmt::ClauseNode::OR:
498            return new LogicalExpr( loc,
499                new CastExpr( loc, genPredExpr( loc, currNode->left, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ),
500                new CastExpr( loc, genPredExpr( loc, currNode->right, idx, genLeaf ), new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast ), 
501                LogicalFlag::OrExpr );
502        case WaitUntilStmt::ClauseNode::LEAF:
503            return genLeaf( loc, idx );
504        default:
505            assertf(false, "Unreachable waituntil clause node type. How did you get here???");
506    }
507}
508
509
510// Builds the predicate functions used to check the status of the waituntil statement
511/* Ex:
512{
513    waituntil( A ){ doA(); }
514    or waituntil( B ){ doB(); }
515    and waituntil( C ) { doC(); }
516}
517generates =>
518static inline bool is_full_sat_1( int * clause_statuses ) {
519    return clause_statuses[0]
520        || clause_statuses[1]
521        && clause_statuses[2];
522}
523
524static inline bool is_done_sat_1( int * clause_statuses ) {
525    return has_run(clause_statuses[0])
526        || has_run(clause_statuses[1])
527        && has_run(clause_statuses[2]);
528}
529*/
530// Returns a predicate function decl
531// predName and genLeaf determine if this generates an is_done or an is_full predicate
532FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
533    int arrIdx = 0;
534    const CodeLocation & loc = stmt->location;
535    CompoundStmt * body = new CompoundStmt( loc );
536    body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
537
538    return new FunctionDecl( loc,
539        predName,
540        {},                     // forall
541        {
542            new ObjectDecl( loc,
543                "clause_statuses",
544                new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) )
545            )
546        },
547        { 
548            new ObjectDecl( loc,
549                "sat_ret",
550                new BasicType( BasicType::Kind::Bool )
551            )
552        },
553        body,               // body
554        { Storage::Static },    // storage
555        Linkage::Cforall,       // linkage
556        {},                     // attributes
557        { Function::Inline }
558    );
559}
560
561// Creates is_done and is_full predicates
562void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
563    if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
564        satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) ); 
565    satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
566}
567
568// Adds the following to body:
569// if ( when_cond ) { // this if is omitted if no when() condition
570//      setup_clause( clause1, &clause_statuses[0], &park_counter );
571//      register_select(A, clause1);
572// }
573void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {   
574    CompoundStmt * currBody = body;
575    const CodeLocation & loc = clause->location;
576
577    // If we have a when_cond make the initialization conditional
578    if ( clause->when_cond )
579        currBody = new CompoundStmt( loc );
580
581    // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
582    currBody->push_back( new ExprStmt( loc,
583        new UntypedExpr ( loc,
584            new NameExpr( loc, "setup_clause" ),
585            {
586                new NameExpr( loc, data->nodeName ),
587                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
588                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
589            }
590        )
591    ));
592
593    // Generates: register_select(A, clause1);
594    currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
595
596    // generates: if ( when_cond ) { ... currBody ... }
597    if ( clause->when_cond )
598        body->push_back( 
599            new IfStmt( loc,
600                new NameExpr( loc, data->whenName ),
601                currBody
602            )
603        );
604}
605
606// Used to generate a call to one of the select trait routines
607Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
608    const CodeLocation & loc = clause->location;
609    return new UntypedExpr ( loc,
610        new NameExpr( loc, fnName ),
611        {
612            new NameExpr( loc, data->targetName ),
613            new NameExpr( loc, data->nodeName )
614        }
615    );
616}
617
618// Generates:
619/* if ( on_selected( target_1, node_1 )) ... corresponding body of target_1 ...
620*/
621CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
622    const CodeLocation & cLoc = clause->location;
623    return new CompoundStmt( cLoc,
624        {
625            new IfStmt( cLoc,
626                genSelectTraitCall( clause, data, "on_selected" ),
627                new CompoundStmt( cLoc,
628                    {
629                        ast::deepCopy( clause->stmt )
630                    }
631                )
632            )
633        }
634    );
635}
636
637// this routine generates and returns the following
638/*for ( int i = 0; i < numClauses; i++ ) {
639    if ( predName(clause_statuses) ) break;
640    if (clause_statuses[i] == __SELECT_SAT) {
641        switch (i) {
642            case 0:
643                try {
644                    if (on_selected( target1, clause1 ))
645                        dotarget1stmt();
646                }
647                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
648                break;
649            ...
650            case N:
651                ...
652                break;
653        }
654    }
655}*/
656ForStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
657    CompoundStmt * ifBody = new CompoundStmt( stmt->location );
658    const CodeLocation & loc = stmt->location;
659
660    /* generates:
661    switch (i) {
662        case 0:
663            try {
664                if (on_selected( target1, clause1 ))
665                    dotarget1stmt();
666            }
667            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
668            break;
669            ...
670        case N:
671            ...
672            break;
673    }*/
674    std::vector<ptr<CaseClause>> switchCases;
675    int idx = 0;
676    for ( const auto & clause: stmt->clauses ) {
677        const CodeLocation & cLoc = clause->location;
678        switchCases.push_back(
679            new CaseClause( cLoc,
680                ConstantExpr::from_int( cLoc, idx ),
681                {
682                    new CompoundStmt( cLoc,
683                        {
684                            new ast::TryStmt( cLoc,
685                                genStmtBlock( clause, clauseData.at(idx) ),
686                                {},
687                                new ast::FinallyClause( cLoc, 
688                                    new CompoundStmt( cLoc,
689                                        {
690                                            new ExprStmt( loc,
691                                                new UntypedExpr ( loc,
692                                                    new NameExpr( loc, "?=?" ),
693                                                    {
694                                                        new UntypedExpr ( loc, 
695                                                            new NameExpr( loc, "?[?]" ),
696                                                            {
697                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
698                                                                new NameExpr( loc, idxName )
699                                                            }
700                                                        ),
701                                                        new NameExpr( loc, "__SELECT_RUN" )
702                                                    }
703                                                )
704                                            ),
705                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
706                                        }
707                                    )
708                                )
709                            ),
710                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc ) )
711                        }
712                    )
713                }
714            )
715        );
716        idx++;
717    }
718
719    ifBody->push_back(
720        new SwitchStmt( loc,
721            new NameExpr( loc, idxName ),
722            std::move( switchCases )
723        )
724    );
725
726    // gens:
727    // if (clause_statuses[i] == __SELECT_SAT) {
728    //      ... ifBody  ...
729    // }
730    IfStmt * ifSwitch = new IfStmt( loc,
731        new UntypedExpr ( loc,
732            new NameExpr( loc, "?==?" ),
733            {
734                new UntypedExpr ( loc, 
735                    new NameExpr( loc, "?[?]" ),
736                    {
737                        new NameExpr( loc, clauseData.at(0)->statusName ),
738                        new NameExpr( loc, idxName )
739                    }
740                ),
741                new NameExpr( loc, "__SELECT_SAT" )
742            }
743        ),      // condition
744        ifBody  // body
745    );
746
747    return new ForStmt( loc,
748        {
749            new DeclStmt( loc,
750                new ObjectDecl( loc,
751                    idxName,
752                    new BasicType( BasicType::Kind::SignedInt ),
753                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
754                )
755            )
756        },  // inits
757        new UntypedExpr ( loc,
758            new NameExpr( loc, "?<?" ),
759            {
760                new NameExpr( loc, idxName ),
761                ConstantExpr::from_int( loc, stmt->clauses.size() )
762            }
763        ),  // cond
764        new UntypedExpr ( loc,
765            new NameExpr( loc, "?++" ),
766            { new NameExpr( loc, idxName ) }
767        ),  // inc
768        new CompoundStmt( loc,
769            {
770                new IfStmt( loc,
771                    new UntypedExpr ( loc,
772                        new NameExpr( loc, predName ),
773                        { new NameExpr( loc, clauseData.at(0)->statusName ) }
774                    ),
775                    new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc ) )
776                ),
777                ifSwitch
778            }
779        )   // body
780    );
781}
782
783// Generates: !is_full_sat_n() / !is_run_sat_n()
784Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
785    const CodeLocation & loc = stmt->location;
786    return new UntypedExpr ( loc,
787        new NameExpr( loc, "!?" ),
788        {
789            new UntypedExpr ( loc,
790                new NameExpr( loc, satName ),
791                { new NameExpr( loc, arrName ) }
792            )
793        }
794    );
795}
796
797// Generates the code needed for waituntils with an else ( ... )
798// Checks clauses once after registering for completion and runs them if completes
799// If not enough have run to satisfy predicate after one pass then the else is run
800Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
801    return new CompoundStmt( stmt->else_stmt->location,
802        {
803            genStatusCheckFor( stmt, clauseData, runName ),
804            new IfStmt( stmt->else_stmt->location,
805                genNotSatExpr( stmt, runName, arrName ),
806                ast::deepCopy( stmt->else_stmt )
807            )
808        }
809    );
810}
811
812Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & satName, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
813    CompoundStmt * whileBody = new CompoundStmt( stmt->location );
814    const CodeLocation & loc = stmt->location;
815
816    // generates: __CFA_maybe_park( &park_counter );
817    whileBody->push_back(
818        new ExprStmt( loc,
819            new UntypedExpr ( loc,
820                new NameExpr( loc, "__CFA_maybe_park" ),
821                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
822            )
823        )
824    );
825
826    whileBody->push_back( genStatusCheckFor( stmt, clauseData, satName ) );
827
828    return new CompoundStmt( loc,
829        {
830            new WhileDoStmt( loc,
831                genNotSatExpr( stmt, satName, arrName ),
832                whileBody,  // body
833                {}          // no inits
834            ),
835            genStatusCheckFor( stmt, clauseData, runName )
836        }
837    );
838}
839
840// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
841// typeof(target) & __clause_target_0 = target;
842// bool __when_cond_0 = when_cond; // only generated if when_cond defined
843// select_node clause1;
844void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
845    ClauseData * currClause;
846    for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
847        currClause = new ClauseData( i, statusName );
848        currClause->nodeName = namer_node.newName();
849        currClause->targetName = namer_target.newName();
850        currClause->whenName = namer_when.newName();
851        clauseData.push_back(currClause);
852        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
853
854        // typeof(target) & __clause_target_0 = target;
855        body->push_back(
856            new DeclStmt( cLoc,
857                new ObjectDecl( cLoc,
858                    currClause->targetName,
859                    new ReferenceType( new TypeofType( ast::deepCopy( stmt->clauses.at(i)->target ) ) ),
860                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
861                )
862            )
863        );
864
865        // bool __when_cond_0 = when_cond; // only generated if when_cond defined
866        if ( stmt->clauses.at(i)->when_cond )
867            body->push_back(
868                new DeclStmt( cLoc,
869                    new ObjectDecl( cLoc,
870                        currClause->whenName,
871                        new BasicType( BasicType::Kind::Bool ),
872                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
873                    )
874                )
875            );
876       
877        // select_node clause1;
878        body->push_back(
879            new DeclStmt( cLoc,
880                new ObjectDecl( cLoc,
881                    currClause->nodeName,
882                    new StructInstType( selectNodeDecl )
883                )
884            )
885        );
886    }
887
888    if ( stmt->else_stmt && stmt->else_cond ) {
889        body->push_back(
890            new DeclStmt( stmt->else_cond->location,
891                new ObjectDecl( stmt->else_cond->location,
892                    elseWhenName,
893                    new BasicType( BasicType::Kind::Bool ),
894                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
895                )
896            )
897        );
898    }
899}
900
901/*
902if ( clause_status == &clause1 ) ... clause 1 body ...
903...
904elif ( clause_status == &clausen ) ... clause n body ...
905*/
906Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
907    const CodeLocation & loc = stmt->location;
908
909    IfStmt * outerIf = nullptr;
910        IfStmt * lastIf = nullptr;
911
912        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
913        for ( long unsigned int i = 0; i < data.size(); i++ ) {
914        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
915
916                IfStmt * currIf = new IfStmt( cLoc,
917                        new UntypedExpr( cLoc, 
918                new NameExpr( cLoc, "?==?" ), 
919                {
920                    new NameExpr( cLoc, statusName ),
921                    new CastExpr( cLoc, 
922                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
923                        new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
924                    )
925                }
926            ),
927            genStmtBlock( stmt->clauses.at(i), data.at(i) )
928                );
929               
930                if ( i == 0 ) {
931                        outerIf = currIf;
932                } else {
933                        // add ifstmt to else of previous stmt
934                        lastIf->else_ = currIf;
935                }
936
937                lastIf = currIf;
938        }
939
940    // C_TODO: will remove this commented code later. Currently it isn't needed but may switch to a modified version of this later if it has better performance
941    // std::vector<ptr<CaseClause>> switchCases;
942
943    // int idx = 0;
944    // for ( const auto & clause: stmt->clauses ) {
945    //     const CodeLocation & cLoc = clause->location;
946    //     switchCases.push_back(
947    //         new CaseClause( cLoc,
948    //             new CastExpr( cLoc,
949    //                 new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->targetName ) ),
950    //                 new BasicType( BasicType::Kind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
951    //             ),
952    //             {
953    //                 new CompoundStmt( cLoc,
954    //                     {
955    //                         ast::deepCopy( clause->stmt ),
956    //                         new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc ) )
957    //                     }
958    //                 )
959    //             }
960    //         )
961    //     );
962    //     idx++;
963    // }
964
965    return new CompoundStmt( loc,
966        {
967            new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
968            outerIf
969            // new SwitchStmt( loc,
970            //     new NameExpr( loc, statusName ),
971            //     std::move( switchCases )
972            // )
973        }
974    );
975}
976
977Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
978    if ( idx == data.size() ) {   // base case, gen last else
979        const CodeLocation & cLoc = stmt->else_stmt->location;
980        if ( !stmt->else_stmt ) // normal non-else gen
981            return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
982
983        Expr * raceFnCall = new UntypedExpr( stmt->location,
984            new NameExpr( stmt->location, "__select_node_else_race" ),
985            { new NameExpr( stmt->location, data.at(0)->nodeName ) }
986        );
987
988        if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
989            return new IfStmt( cLoc,
990                new LogicalExpr( cLoc,
991                    new CastExpr( cLoc,
992                        new NameExpr( cLoc, elseWhenName ),
993                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
994                    ),
995                    new CastExpr( cLoc,
996                        raceFnCall,
997                        new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
998                    ),
999                    LogicalFlag::AndExpr
1000                ),
1001                ast::deepCopy( stmt->else_stmt ),
1002                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1003            );
1004        }
1005
1006        // return else conditional on race
1007        return new IfStmt( stmt->else_stmt->location,
1008            raceFnCall,
1009            ast::deepCopy( stmt->else_stmt ),
1010            buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1011        );
1012    }
1013    const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1014
1015    Expr * ifCond;
1016
1017    // If we have a when_cond make the register call conditional on it
1018    if ( stmt->clauses.at(idx)->when_cond ) {
1019        ifCond = new LogicalExpr( cLoc,
1020            new CastExpr( cLoc,
1021                new NameExpr( cLoc, data.at(idx)->whenName ), 
1022                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1023            ),
1024            new CastExpr( cLoc,
1025                genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" ),
1026                new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1027            ),
1028            LogicalFlag::AndExpr
1029        );
1030    } else ifCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
1031
1032    return new CompoundStmt( cLoc,
1033        {   // gens: setup_clause( clause1, &status, 0p );
1034            new ExprStmt( cLoc,
1035                new UntypedExpr ( cLoc,
1036                    new NameExpr( cLoc, "setup_clause" ),
1037                    {
1038                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1039                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1040                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::SignedInt ) ) )
1041                    }
1042                )
1043            ),
1044            // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1045            new IfStmt( cLoc,
1046                ifCond,
1047                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1048                // ast::deepCopy( stmt->clauses.at(idx)->stmt ),
1049                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1050            )
1051        }
1052    );
1053}
1054
1055// This gens the special case of an all OR waituntil:
1056/*
1057int status = 0;
1058
1059typeof(target) & __clause_target_0 = target;
1060bool __when_cond_0 = when_cond; // only generated if when_cond defined
1061select_node clause1;
1062... generate above for rest of clauses ...
1063
1064try {
1065    setup_clause( clause1, &status, 0p );
1066    if ( __when_cond_0 && register_select( 1 ) ) {
1067        ... clause 1 body ...
1068    } else {
1069        ... recursively gen for each of n clauses ...
1070        setup_clause( clausen, &status, 0p );
1071        if ( __when_cond_n-1 && register_select( n ) ) {
1072            ... clause n body ...
1073        } else {
1074            if ( else_when ) ... else clause body ...
1075            else {
1076                park();
1077
1078                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1079                if ( clause_status == &clause1) ... clause 1 body ...
1080                ...
1081                elif ( clause_status == &clausen ) ... clause n body ...
1082            }
1083        }
1084    }
1085}
1086finally {
1087    if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1088    ...
1089    if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1090}
1091*/
1092Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1093    const CodeLocation & loc = stmt->location;
1094    string statusName = namer_status.newName();
1095    string elseWhenName = namer_when.newName();
1096    int numClauses = stmt->clauses.size();
1097    CompoundStmt * body = new CompoundStmt( stmt->location );
1098
1099    // Generates: unsigned long int status = 0;
1100    body->push_back( new DeclStmt( loc,
1101        new ObjectDecl( loc,
1102            statusName,
1103            new BasicType( BasicType::Kind::LongUnsignedInt ),
1104            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1105        )
1106    ));
1107
1108    vector<ClauseData *> clauseData;
1109    genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1110
1111    vector<int> whenIndices; // track which clauses have whens
1112
1113    CompoundStmt * unregisters = new CompoundStmt( loc );
1114    Expr * ifCond;
1115    for ( int i = 0; i < numClauses; i++ ) {
1116        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1117        // Gens: node.status != 0p
1118        UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc, 
1119            new NameExpr( cLoc, "?!=?" ), 
1120            {
1121                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicType::Kind::LongUnsignedInt ) ) ),
1122                new UntypedExpr( cLoc, 
1123                    new NameExpr( cLoc, "__get_clause_status" ), 
1124                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) } 
1125                ) 
1126            }
1127        );
1128
1129        // If we have a when_cond make the unregister call conditional on it
1130        if ( stmt->clauses.at(i)->when_cond ) {
1131            whenIndices.push_back(i);
1132            ifCond = new LogicalExpr( cLoc,
1133                new CastExpr( cLoc,
1134                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1135                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1136                ),
1137                new CastExpr( cLoc,
1138                    statusPtrCheck,
1139                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1140                ),
1141                LogicalFlag::AndExpr
1142            );
1143        } else ifCond = statusPtrCheck;
1144       
1145        unregisters->push_back(
1146            new IfStmt( cLoc,
1147                ifCond,
1148                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) ) 
1149            )
1150        );
1151    }
1152
1153    if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1154        body->push_back(
1155                new ast::TryStmt( loc,
1156                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1157                {},
1158                new ast::FinallyClause( loc, unregisters )
1159            )
1160        );
1161    } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1162        Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1163        Expr * lastExpr = outerIfCond;
1164
1165        for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1166            outerIfCond = new LogicalExpr( loc,
1167                new CastExpr( loc,
1168                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ), 
1169                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1170                ),
1171                new CastExpr( loc,
1172                    lastExpr,
1173                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1174                ),
1175                LogicalFlag::OrExpr
1176            );
1177            lastExpr = outerIfCond;
1178        }
1179
1180        body->push_back(
1181                new ast::TryStmt( loc,
1182                new CompoundStmt( loc, 
1183                    {
1184                        new IfStmt( loc,
1185                            outerIfCond,
1186                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1187                        )
1188                    }
1189                ),
1190                {},
1191                new ast::FinallyClause( loc, unregisters )
1192            )
1193        );
1194    }
1195
1196    for ( ClauseData * datum : clauseData )
1197        delete datum;
1198
1199    return body;
1200}
1201
1202Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1203    if ( !selectNodeDecl )
1204        SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1205
1206    // Prep clause tree to figure out how to set initial statuses
1207    // setTreeSizes( stmt->predicateTree );
1208    if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1209        return genAllOr( stmt );
1210
1211    CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1212    CompoundStmt * body = new CompoundStmt( stmt->location );
1213    string statusArrName = namer_status.newName();
1214    string pCountName = namer_park.newName();
1215    string satName = namer_sat.newName();
1216    string runName = namer_run.newName();
1217    string elseWhenName = namer_when.newName();
1218    int numClauses = stmt->clauses.size();
1219    addPredicates( stmt, satName, runName );
1220
1221    const CodeLocation & loc = stmt->location;
1222
1223    // Generates: int park_counter = 0;
1224    body->push_back( new DeclStmt( loc,
1225        new ObjectDecl( loc,
1226            pCountName,
1227            new BasicType( BasicType::Kind::SignedInt ),
1228            new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1229        )
1230    ));
1231
1232    // Generates: int clause_statuses[3] = { 0 };
1233    body->push_back( new DeclStmt( loc,
1234        new ObjectDecl( loc,
1235            statusArrName,
1236            new ArrayType( new BasicType( BasicType::Kind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1237            new ListInit( loc,
1238                {
1239                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1240                }
1241            )
1242        )
1243    ));
1244
1245    vector<ClauseData *> clauseData;
1246    genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1247
1248    vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1249    vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1250
1251    collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1252
1253    // This is only needed for clauses that have AND as a parent and a when_cond defined
1254    // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1255    for ( int idx : andWhenClauses ) {
1256        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1257        body->push_back( 
1258            new IfStmt( cLoc,
1259                new UntypedExpr ( cLoc,
1260                    new NameExpr( cLoc, "!?" ),
1261                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1262                ),  // IfStmt cond
1263                new ExprStmt( cLoc,
1264                    new UntypedExpr ( cLoc,
1265                        new NameExpr( cLoc, "?=?" ),
1266                        {
1267                            new UntypedExpr ( cLoc, 
1268                                new NameExpr( cLoc, "?[?]" ),
1269                                {
1270                                    new NameExpr( cLoc, statusArrName ),
1271                                    ConstantExpr::from_int( cLoc, idx )
1272                                }
1273                            ),
1274                            new NameExpr( cLoc, "__SELECT_RUN" )
1275                        }
1276                    )
1277                )  // IfStmt then
1278            )
1279        );
1280    }
1281
1282    // Only need to generate conditional initial state setting for ambiguous when clauses
1283    if ( !ambiguousClauses.empty() ) {
1284        body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1285    }
1286
1287    // generates the following for each clause:
1288    // setup_clause( clause1, &clause_statuses[0], &park_counter );
1289    // register_select(A, clause1);
1290    for ( int i = 0; i < numClauses; i++ ) {
1291        setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1292    }
1293
1294    // generate satisfy logic based on if there is an else clause and if it is conditional
1295    if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1296        tryBody->push_back(
1297            new IfStmt( stmt->else_cond->location,
1298                new NameExpr( stmt->else_cond->location, elseWhenName ),
1299                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1300                genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData )
1301            )
1302        );
1303    } else if ( !stmt->else_stmt ) { // normal gen
1304        tryBody->push_back( genNoElseClauseBranch( stmt, satName, runName, statusArrName, pCountName, clauseData ) );
1305    } else { // generate just else
1306        tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1307    }
1308
1309    CompoundStmt * unregisters = new CompoundStmt( loc );
1310    // generates for each clause:
1311    // if ( !has_run( clause_statuses[i] ) )
1312    // OR if when_cond defined
1313    // if ( when_cond_i && !has_run( clause_statuses[i] ) )
1314    // body of if is:
1315    // { if (unregister_select(A, clause1) && on_selected(A, clause1)) clause1->stmt; } // this conditionally runs the block unregister_select returns true (needed by some primitives)
1316    Expr * ifCond;
1317    UntypedExpr * statusExpr; // !clause_statuses[i]
1318    for ( int i = 0; i < numClauses; i++ ) {
1319        const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1320
1321        statusExpr = new UntypedExpr ( cLoc,
1322            new NameExpr( cLoc, "!?" ),
1323            {
1324                new UntypedExpr ( cLoc, 
1325                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1326                    {
1327                        genArrAccessExpr( cLoc, i, statusArrName )
1328                    }
1329                )
1330            }
1331        );
1332
1333        if ( stmt->clauses.at(i)->when_cond ) {
1334            // generates: if( when_cond_i && !has_run(clause_statuses[i]) )
1335            ifCond = new LogicalExpr( cLoc,
1336                new CastExpr( cLoc,
1337                    new NameExpr( cLoc, clauseData.at(i)->whenName ), 
1338                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1339                ),
1340                new CastExpr( cLoc,
1341                    statusExpr,
1342                    new BasicType( BasicType::Kind::Bool ), GeneratedFlag::ExplicitCast
1343                ),
1344                LogicalFlag::AndExpr
1345            );
1346        } else // generates: if( !clause_statuses[i] )
1347            ifCond = statusExpr;
1348       
1349        unregisters->push_back( 
1350            new IfStmt( cLoc,
1351                ifCond,
1352                new CompoundStmt( cLoc,
1353                    {
1354                        new IfStmt( cLoc,
1355                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1356                            // ast::deepCopy( stmt->clauses.at(i)->stmt )
1357                            genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1358                        )
1359                    }
1360                )
1361               
1362            )
1363        );
1364    }
1365
1366    body->push_back( 
1367        new ast::TryStmt(
1368            loc,
1369            tryBody,
1370            {},
1371            new ast::FinallyClause( loc, unregisters )
1372        )
1373    );
1374
1375    for ( ClauseData * datum : clauseData )
1376        delete datum;
1377
1378    return body;
1379}
1380
1381// To add the predicates at global scope we need to do it in a second pass
1382// Predicates are added after "struct select_node { ... };"
1383class AddPredicateDecls final : public WithDeclsToAdd<> {
1384    vector<FunctionDecl *> & satFns;
1385    const StructDecl * selectNodeDecl = nullptr;
1386
1387  public:
1388    void previsit( const StructDecl * decl ) {
1389        if ( !decl->body ) {
1390            return;
1391        } else if ( "select_node" == decl->name ) {
1392            assert( !selectNodeDecl );
1393            selectNodeDecl = decl;
1394            for ( FunctionDecl * fn : satFns )
1395                declsToAddAfter.push_back(fn);           
1396        }
1397    }
1398    AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1399};
1400
1401void generateWaitUntil( TranslationUnit & translationUnit ) {
1402    vector<FunctionDecl *> satFns;
1403        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1404    Pass<AddPredicateDecls>::run( translationUnit, satFns );
1405}
1406
1407} // namespace Concurrency
1408
1409// Local Variables: //
1410// tab-width: 4 //
1411// mode: c++ //
1412// compile-command: "make install" //
1413// End: //
Note: See TracBrowser for help on using the repository browser.