source: src/Concurrency/Waituntil.cpp @ d96f7c4

Last change on this file since d96f7c4 was ed96731, checked in by Andrew Beach <ajbeach@…>, 2 months ago

With{Stmts,Decls}ToAdd? how has an -X version like WithSymbolTableX. Although these -X versions might be useful can could possibly be removed in the future. (This is a therapy commit.)

  • Property mode set to 100644
File size: 49.9 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waituntil.hpp"
17
18#include <string>
19
20#include "AST/Copy.hpp"
21#include "AST/Expr.hpp"
22#include "AST/Pass.hpp"
23#include "AST/Print.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/Type.hpp"
26#include "Common/UniqueName.hpp"
27
28using namespace ast;
29using namespace std;
30
31/* So this is what this pass dones:
32{
33        when ( condA ) waituntil( A ){ doA(); }
34        or when ( condB ) waituntil( B ){ doB(); }
35        and when ( condC ) waituntil( C ) { doC(); }
36}
37                 ||
38                 ||
39                \||/
40                 \/
41
42Generates these two routines:
43static inline bool is_full_sat_1( int * clause_statuses ) {
44        return clause_statuses[0]
45                || clause_statuses[1]
46                && clause_statuses[2];
47}
48
49static inline bool is_done_sat_1( int * clause_statuses ) {
50        return has_run(clause_statuses[0])
51                || has_run(clause_statuses[1])
52                && has_run(clause_statuses[2]);
53}
54
55Replaces the waituntil statement above with the following code:
56{
57        // used with atomic_dec/inc to get binary semaphore behaviour
58        int park_counter = 0;
59
60        // status (one for each clause)
61        int clause_statuses[3] = { 0 };
62
63        bool whenA = condA;
64        bool whenB = condB;
65        bool whenC = condC;
66
67        if ( !whenB ) clause_statuses[1] = __SELECT_RUN;
68        if ( !whenC ) clause_statuses[2] = __SELECT_RUN;
69
70        // some other conditional settors for clause_statuses are set here, see genSubtreeAssign and related routines
71
72        // three blocks
73        // for each block, create, setup, then register select_node
74        select_node clause1;
75        select_node clause2;
76        select_node clause3;
77
78        try {
79                if ( whenA ) { register_select(A, clause1); setup_clause( clause1, &clause_statuses[0], &park_counter ); }
80                ... repeat ^ for B and C ...
81
82                // if else clause is defined a separate branch can occur here to set initial values, see genWhenStateConditions
83
84                // loop & park until done
85                while( !is_full_sat_1( clause_statuses ) ) {
86
87                        // binary sem P();
88                        if ( __atomic_sub_fetch( &park_counter, 1, __ATOMIC_SEQ_CST) < 0 )
89                                park();
90
91                        // execute any blocks available with status set to 0
92                        for ( int i = 0; i < 3; i++ ) {
93                                if (clause_statuses[i] == __SELECT_SAT) {
94                                    switch (i) {
95                                        case 0:
96                                            try {
97                                                    on_selected( A, clause1 );
98                                                    doA();
99                                            }
100                                            finally { clause_statuses[i] = __SELECT_RUN; unregister_select(A, clause1); }
101                                            break;
102                                        case 1:
103                                            ... same gen as A but for B and clause2 ...
104                                            break;
105                                        case 2:
106                                            ... same gen as A but for C and clause3 ...
107                                            break;
108                                    }
109                                }
110                        }
111                }
112
113                // ensure that the blocks that triggered is_full_sat_1 are run
114                // by running every un-run block that is SAT from the start until
115                // the predicate is SAT when considering RUN status = true
116                for ( int i = 0; i < 3; i++ ) {
117                        if (is_done_sat_1( clause_statuses )) break;
118                        if (clause_statuses[i] == __SELECT_SAT)
119                                ... Same if body here as in loop above ...
120                }
121        } finally {
122                // the unregister and on_selected calls are needed to support primitives where the acquire has side effects
123                // so the corresponding block MUST be run for those primitives to not lose state (example is channels)
124                if ( !has_run(clause_statuses[0]) && whenA && unregister_select(A, clause1) )
125                        on_selected( A, clause1 )
126                        doA();
127                ... repeat if above for B and C ...
128        }
129}
130
131*/
132
133namespace Concurrency {
134
135class GenerateWaitUntilCore final {
136        vector<FunctionDecl *> & satFns;
137        UniqueName namer_sat = "__is_full_sat_"s;
138        UniqueName namer_run = "__is_run_sat_"s;
139        UniqueName namer_park = "__park_counter_"s;
140        UniqueName namer_status = "__clause_statuses_"s;
141        UniqueName namer_node = "__clause_"s;
142        UniqueName namer_target = "__clause_target_"s;
143        UniqueName namer_when = "__when_cond_"s;
144        UniqueName namer_label = "__waituntil_label_"s;
145
146        string idxName = "__CFA_clause_idx_";
147
148        struct ClauseData {
149                string nodeName;
150                string targetName;
151                string whenName;
152                int index;
153                string & statusName;
154                ClauseData( int index, string & statusName ) : index(index), statusName(statusName) {}
155        };
156
157        const StructDecl * selectNodeDecl = nullptr;
158
159        // This first set of routines are all used to do the complicated job of
160        //    dealing with how to set predicate statuses with certain when_conds T/F
161        //    so that the when_cond == F effectively makes that clause "disappear"
162        void updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow );
163        void paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow );
164        bool paintWhenTree( WaitUntilStmt::ClauseNode * currNode );
165        void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd );
166        void collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs );
167        void updateWhenState( WaitUntilStmt::ClauseNode * currNode );
168        void genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
169        void genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData );
170        CompoundStmt * getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData );
171        Stmt * genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx );
172
173        // These routines are just code-gen helpers
174        void addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName );
175        void setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body );
176        CompoundStmt * genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName );
177        Expr * genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName );
178        CompoundStmt * genStmtBlock( const WhenClause * clause, const ClauseData * data );
179        Stmt * genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData );
180        Stmt * genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData );
181        void genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName );
182        Stmt * recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName );
183        Stmt * buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data );
184        Stmt * genAllOr( const WaitUntilStmt * stmt );
185
186  public:
187        void previsit( const StructDecl * decl );
188        Stmt * postvisit( const WaitUntilStmt * stmt );
189        GenerateWaitUntilCore( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
190};
191
192// Finds select_node decl
193void GenerateWaitUntilCore::previsit( const StructDecl * decl ) {
194        if ( !decl->body ) {
195                return;
196        } else if ( "select_node" == decl->name ) {
197                assert( !selectNodeDecl );
198                selectNodeDecl = decl;
199        }
200}
201
202void GenerateWaitUntilCore::updateAmbiguousWhen( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool andBelow, bool orBelow ) {
203        // all children when-ambiguous
204        if ( currNode->left->ambiguousWhen && currNode->right->ambiguousWhen )
205                // true iff an ancestor/descendant has a different operation
206                currNode->ambiguousWhen = (orAbove || orBelow) && (andBelow || andAbove);
207        // ambiguousWhen is initially false so theres no need to set it here
208}
209
210// Traverses ClauseNode tree and paints each AND/OR node as when-ambiguous true or false
211// This tree painting is needed to generate the if statements that set the initial state
212//    of the clause statuses when some clauses are turned off via when_cond
213// An internal AND/OR node is when-ambiguous if it satisfies all of the following:
214// - It has an ancestor or descendant that is a different operation, i.e. (AND has an OR ancestor or vice versa)
215// - All of its descendent clauses are optional, i.e. they have a when_cond defined on the WhenClause
216void GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode, bool andAbove, bool orAbove, bool & andBelow, bool & orBelow ) {
217        bool aBelow = false; // updated by child nodes
218        bool oBelow = false; // updated by child nodes
219        switch (currNode->op) {
220                case WaitUntilStmt::ClauseNode::AND:
221                        paintWhenTree( currNode->left, true, orAbove, aBelow, oBelow );
222                        paintWhenTree( currNode->right, true, orAbove, aBelow, oBelow );
223
224                        // update currNode's when flag based on conditions listed in fn signature comment above
225                        updateAmbiguousWhen(currNode, true, orAbove, aBelow, oBelow );
226
227                        // set return flags to tell parents which decendant ops have been seen
228                        andBelow = true;
229                        orBelow = oBelow;
230                        return;
231                case WaitUntilStmt::ClauseNode::OR:
232                        paintWhenTree( currNode->left, andAbove, true, aBelow, oBelow );
233                        paintWhenTree( currNode->right, andAbove, true, aBelow, oBelow );
234
235                        // update currNode's when flag based on conditions listed in fn signature comment above
236                        updateAmbiguousWhen(currNode, andAbove, true, aBelow, oBelow );
237
238                        // set return flags to tell parents which decendant ops have been seen
239                        andBelow = aBelow;
240                        orBelow = true;
241                        return;
242                case WaitUntilStmt::ClauseNode::LEAF:
243                        if ( currNode->leaf->when_cond )
244                                currNode->ambiguousWhen = true;
245                        return;
246                default:
247                        assertf(false, "Unreachable waituntil clause node type. How did you get here???");
248        }
249}
250
251// overloaded wrapper for paintWhenTree that sets initial values
252// returns true if entire tree is OR's (special case)
253bool GenerateWaitUntilCore::paintWhenTree( WaitUntilStmt::ClauseNode * currNode ) {
254        bool aBelow = false, oBelow = false; // unused by initial call
255        paintWhenTree( currNode, false, false, aBelow, oBelow );
256        return !aBelow;
257}
258
259// Helper: returns Expr that represents arrName[index]
260Expr * genArrAccessExpr( const CodeLocation & loc, int index, string arrName ) {
261        return new UntypedExpr ( loc,
262                new NameExpr( loc, "?[?]" ),
263                {
264                        new NameExpr( loc, arrName ),
265                        ConstantExpr::from_int( loc, index )
266                }
267        );
268}
269
270// After the ClauseNode AND/OR nodes are painted this routine is called to traverses the tree and does the following:
271// - collects a set of indices in the clause arr that refer whenclauses that can have ambiguous status assignments (ambigIdxs)
272// - collects a set of indices in the clause arr that refer whenclauses that have a when() defined and an AND node as a parent (andIdxs)
273// - updates LEAF nodes to be when-ambiguous if their direct parent is when-ambiguous.
274void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs, int & index, bool parentAmbig, bool parentAnd ) {
275        switch (currNode->op) {
276                case WaitUntilStmt::ClauseNode::AND:
277                        collectWhens( currNode->left, ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
278                        collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, true );
279                        return;
280                case WaitUntilStmt::ClauseNode::OR:
281                        collectWhens( currNode->left,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
282                        collectWhens( currNode->right,  ambigIdxs, andIdxs, index, currNode->ambiguousWhen, false );
283                        return;
284                case WaitUntilStmt::ClauseNode::LEAF:
285                        if ( parentAmbig ) {
286                                ambigIdxs.push_back(make_pair(index, currNode));
287                        }
288                        if ( parentAnd && currNode->leaf->when_cond ) {
289                                currNode->childOfAnd = true;
290                                andIdxs.push_back(index);
291                        }
292                        index++;
293                        return;
294                default:
295                        assertf(false, "Unreachable waituntil clause node type. How did you get here???");
296        }
297}
298
299// overloaded wrapper for collectWhens that sets initial values
300void GenerateWaitUntilCore::collectWhens( WaitUntilStmt::ClauseNode * currNode, vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigIdxs, vector<int> & andIdxs ) {
301        int idx = 0;
302        collectWhens( currNode, ambigIdxs, andIdxs, idx, false, false );
303}
304
305// recursively updates ClauseNode whenState on internal nodes so that next pass can see which
306//    subtrees are "turned off"
307// sets whenState = false iff both children have whenState == false.
308// similar to paintWhenTree except since paintWhenTree also filtered out clauses we don't need to consider based on op
309// since the ambiguous clauses were filtered in paintWhenTree we don't need to worry about that here
310void GenerateWaitUntilCore::updateWhenState( WaitUntilStmt::ClauseNode * currNode ) {
311        if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) return;
312        updateWhenState( currNode->left );
313        updateWhenState( currNode->right );
314        if ( !currNode->left->whenState && !currNode->right->whenState )
315                currNode->whenState = false;
316        else
317                currNode->whenState = true;
318}
319
320// generates the minimal set of status assignments to ensure predicate subtree passed as currNode evaluates to status
321// assumes that this will only be called on subtrees that are entirely whenState == false
322void GenerateWaitUntilCore::genSubtreeAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, bool status, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
323        if ( ( currNode->op == WaitUntilStmt::ClauseNode::AND && status )
324                || ( currNode->op == WaitUntilStmt::ClauseNode::OR && !status ) ) {
325                // need to recurse on both subtrees if && subtree needs to be true or || subtree needs to be false
326                genSubtreeAssign( stmt, currNode->left, status, idx, retStmt, clauseData );
327                genSubtreeAssign( stmt, currNode->right, status, idx, retStmt, clauseData );
328        } else if ( ( currNode->op == WaitUntilStmt::ClauseNode::OR && status )
329                || ( currNode->op == WaitUntilStmt::ClauseNode::AND && !status ) ) {
330                // only one subtree needs to evaluate to status if && subtree needs to be true or || subtree needs to be false
331                CompoundStmt * leftStmt = new CompoundStmt( stmt->location );
332                CompoundStmt * rightStmt = new CompoundStmt( stmt->location );
333
334                // only one side needs to evaluate to status so we recurse on both subtrees
335                //    but only keep the statements from the subtree with minimal statements
336                genSubtreeAssign( stmt, currNode->left, status, idx, leftStmt, clauseData );
337                genSubtreeAssign( stmt, currNode->right, status, idx, rightStmt, clauseData );
338
339                // append minimal statements to retStmt
340                if ( leftStmt->kids.size() < rightStmt->kids.size() ) {
341                        retStmt->kids.splice( retStmt->kids.end(), leftStmt->kids );
342                } else {
343                        retStmt->kids.splice( retStmt->kids.end(), rightStmt->kids );
344                }
345
346                delete leftStmt;
347                delete rightStmt;
348        } else if ( currNode->op == WaitUntilStmt::ClauseNode::LEAF ) {
349                const CodeLocation & loc = stmt->location;
350                if ( status && !currNode->childOfAnd ) {
351                        retStmt->push_back(
352                                new ExprStmt( loc,
353                                    UntypedExpr::createAssign( loc,
354                                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
355                                        new NameExpr( loc, "__SELECT_RUN" )
356                                    )
357                                )
358                        );
359                } else if ( !status && currNode->childOfAnd ) {
360                        retStmt->push_back(
361                                new ExprStmt( loc,
362                                    UntypedExpr::createAssign( loc,
363                                        genArrAccessExpr( loc, idx, clauseData.at(idx)->statusName ),
364                                        new NameExpr( loc, "__SELECT_UNSAT" )
365                                    )
366                                )
367                        );
368                }
369
370                // No need to generate statements for the following cases since childOfAnd are always set to true
371                //    and !childOfAnd are always false
372                // - status && currNode->childOfAnd
373                // - !status && !currNode->childOfAnd
374                idx++;
375        }
376}
377
378void GenerateWaitUntilCore::genStatusAssign( const WaitUntilStmt * stmt, WaitUntilStmt::ClauseNode * currNode, int & idx, CompoundStmt * retStmt, vector<ClauseData *> & clauseData ) {
379        switch (currNode->op) {
380                case WaitUntilStmt::ClauseNode::AND:
381                        // check which subtrees have all whenState == false (disabled)
382                        if (!currNode->left->whenState && !currNode->right->whenState) {
383                                // this case can only occur when whole tree is disabled since otherwise
384                                //    genStatusAssign( ... ) isn't called on nodes with whenState == false
385                                assert( !currNode->whenState ); // paranoidWWW
386                                // whole tree disabled so pass true so that select is SAT vacuously
387                                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
388                        } else if ( !currNode->left->whenState ) {
389                                // pass true since x && true === x
390                                genSubtreeAssign( stmt, currNode->left, true, idx, retStmt, clauseData );
391                                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
392                        } else if ( !currNode->right->whenState ) {
393                                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
394                                genSubtreeAssign( stmt, currNode->right, true, idx, retStmt, clauseData );
395                        } else {
396                                // if no children with whenState == false recurse normally via break
397                                break;
398                        }
399                        return;
400                case WaitUntilStmt::ClauseNode::OR:
401                        if (!currNode->left->whenState && !currNode->right->whenState) {
402                                assert( !currNode->whenState ); // paranoid
403                                genSubtreeAssign( stmt, currNode, true, idx, retStmt, clauseData );
404                        } else if ( !currNode->left->whenState ) {
405                                // pass false since x || false === x
406                                genSubtreeAssign( stmt, currNode->left, false, idx, retStmt, clauseData );
407                                genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
408                        } else if ( !currNode->right->whenState ) {
409                                genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
410                                genSubtreeAssign( stmt, currNode->right, false, idx, retStmt, clauseData );
411                        } else {
412                                break;
413                        }
414                        return;
415                case WaitUntilStmt::ClauseNode::LEAF:
416                        idx++;
417                        return;
418                default:
419                        assertf(false, "Unreachable waituntil clause node type. How did you get here???");
420        }
421        genStatusAssign( stmt, currNode->left, idx, retStmt, clauseData );
422        genStatusAssign( stmt, currNode->right, idx, retStmt, clauseData );
423}
424
425// generates a minimal set of assignments for status arr based on which whens are toggled on/off
426CompoundStmt * GenerateWaitUntilCore::getStatusAssignment( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData ) {
427        updateWhenState( stmt->predicateTree );
428        CompoundStmt * retval = new CompoundStmt( stmt->location );
429        int idx = 0;
430        genStatusAssign( stmt, stmt->predicateTree, idx, retval, clauseData );
431        return retval;
432}
433
434// generates nested if/elses for all possible assignments of ambiguous when_conds
435// exponential size of code gen but linear runtime O(n), where n is number of ambiguous whens()
436Stmt * GenerateWaitUntilCore::genWhenStateConditions( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData,
437        vector<pair<int, WaitUntilStmt::ClauseNode *>> & ambigClauses, vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type ambigIdx ) {
438        // I hate C++ sometimes, using vector<pair<int, WaitUntilStmt::ClauseNode *>>::size_type for size() comparison seems silly.
439        //    Why is size_type parameterized on the type stored in the vector?????
440
441        const CodeLocation & loc = stmt->location;
442        int clauseIdx = ambigClauses.at(ambigIdx).first;
443        WaitUntilStmt::ClauseNode * currNode = ambigClauses.at(ambigIdx).second;
444        Stmt * thenStmt;
445        Stmt * elseStmt;
446
447        if ( ambigIdx == ambigClauses.size() - 1 ) { // base case
448                currNode->whenState = true;
449                thenStmt = getStatusAssignment( stmt, clauseData );
450                currNode->whenState = false;
451                elseStmt = getStatusAssignment( stmt, clauseData );
452        } else {
453                // recurse both with when enabled and disabled to generate all possible cases
454                currNode->whenState = true;
455                thenStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
456                currNode->whenState = false;
457                elseStmt = genWhenStateConditions( stmt, clauseData, ambigClauses, ambigIdx + 1 );
458        }
459
460        // insert first recursion result in if ( __when_cond_ ) { ... }
461        // insert second recursion result in else { ... }
462        return new CompoundStmt ( loc,
463                {
464                        new IfStmt( loc,
465                                new NameExpr( loc, clauseData.at(clauseIdx)->whenName ),
466                                thenStmt,
467                                elseStmt
468                        )
469                }
470        );
471}
472
473// typedef a fn ptr so that we can reuse genPredExpr
474// genLeafExpr is used to refer to one of the following two routines
475typedef Expr * (*GenLeafExpr)( const CodeLocation & loc, int & index );
476
477// return Expr that represents clause_statuses[index]
478// mutates index to be index + 1
479Expr * genSatExpr( const CodeLocation & loc, int & index ) {
480        return genArrAccessExpr( loc, index++, "clause_statuses" );
481}
482
483// return Expr that represents has_run(clause_statuses[index])
484Expr * genRunExpr( const CodeLocation & loc, int & index ) {
485        return new UntypedExpr ( loc,
486                new NameExpr( loc, "__CFA_has_clause_run" ),
487                { genSatExpr( loc, index ) }
488        );
489}
490
491// Takes in the ClauseNode tree and recursively generates
492// the predicate expr used inside the predicate functions
493Expr * genPredExpr( const CodeLocation & loc, WaitUntilStmt::ClauseNode * currNode, int & idx, GenLeafExpr genLeaf ) {
494        Expr * leftExpr, * rightExpr;
495        switch (currNode->op) {
496                case WaitUntilStmt::ClauseNode::AND:
497                        leftExpr = genPredExpr( loc, currNode->left, idx, genLeaf );
498                        rightExpr = genPredExpr( loc, currNode->right, idx, genLeaf );
499                        return new LogicalExpr( loc,
500                                new CastExpr( loc, leftExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
501                                new CastExpr( loc, rightExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
502                                LogicalFlag::AndExpr
503                        );
504                        break;
505                case WaitUntilStmt::ClauseNode::OR:
506                        leftExpr = genPredExpr( loc, currNode->left, idx, genLeaf );
507                        rightExpr = genPredExpr( loc, currNode->right, idx, genLeaf );
508                        return new LogicalExpr( loc,
509                                new CastExpr( loc, leftExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
510                                new CastExpr( loc, rightExpr, new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast ),
511                                LogicalFlag::OrExpr );
512                        break;
513                case WaitUntilStmt::ClauseNode::LEAF:
514                        return genLeaf( loc, idx );
515                        break;
516                default:
517                        assertf(false, "Unreachable waituntil clause node type. How did you get here???");\
518                        return nullptr;
519                        break;
520        }
521        return nullptr;
522}
523
524
525// Builds the predicate functions used to check the status of the waituntil statement
526/* Ex:
527{
528        waituntil( A ){ doA(); }
529        or waituntil( B ){ doB(); }
530        and waituntil( C ) { doC(); }
531}
532generates =>
533static inline bool is_full_sat_1( int * clause_statuses ) {
534        return clause_statuses[0]
535                || clause_statuses[1]
536                && clause_statuses[2];
537}
538
539static inline bool is_done_sat_1( int * clause_statuses ) {
540        return has_run(clause_statuses[0])
541                || has_run(clause_statuses[1])
542                && has_run(clause_statuses[2]);
543}
544*/
545// Returns a predicate function decl
546// predName and genLeaf determine if this generates an is_done or an is_full predicate
547FunctionDecl * buildPredicate( const WaitUntilStmt * stmt, GenLeafExpr genLeaf, string & predName ) {
548        int arrIdx = 0;
549        const CodeLocation & loc = stmt->location;
550        CompoundStmt * body = new CompoundStmt( loc );
551        body->push_back( new ReturnStmt( loc, genPredExpr( loc,  stmt->predicateTree, arrIdx, genLeaf ) ) );
552
553        return new FunctionDecl( loc,
554                predName,
555                {
556                        new ObjectDecl( loc,
557                                "clause_statuses",
558                                new PointerType( new BasicType( BasicKind::LongUnsignedInt ) )
559                        )
560                },
561                {
562                        new ObjectDecl( loc,
563                                "sat_ret",
564                                new BasicType( BasicKind::Bool )
565                        )
566                },
567                body,               // body
568                { Storage::Static },    // storage
569                Linkage::Cforall,       // linkage
570                {},                     // attributes
571                { Function::Inline }
572        );
573}
574
575// Creates is_done and is_full predicates
576void GenerateWaitUntilCore::addPredicates( const WaitUntilStmt * stmt, string & satName, string & runName ) {
577        if ( !stmt->else_stmt || stmt->else_cond ) // don't need SAT predicate when else variation with no else_cond
578                satFns.push_back( Concurrency::buildPredicate( stmt, genSatExpr, satName ) );
579        satFns.push_back( Concurrency::buildPredicate( stmt, genRunExpr, runName ) );
580}
581
582// Adds the following to body:
583// if ( when_cond ) { // this if is omitted if no when() condition
584//      setup_clause( clause1, &clause_statuses[0], &park_counter );
585//      register_select(A, clause1);
586// }
587void GenerateWaitUntilCore::setUpClause( const WhenClause * clause, ClauseData * data, string & pCountName, CompoundStmt * body ) {
588        CompoundStmt * currBody = body;
589        const CodeLocation & loc = clause->location;
590
591        // If we have a when_cond make the initialization conditional
592        if ( clause->when_cond )
593                currBody = new CompoundStmt( loc );
594
595        // Generates: setup_clause( clause1, &clause_statuses[0], &park_counter );
596        currBody->push_back( new ExprStmt( loc,
597                new UntypedExpr ( loc,
598                        new NameExpr( loc, "setup_clause" ),
599                        {
600                                new NameExpr( loc, data->nodeName ),
601                                new AddressExpr( loc, genArrAccessExpr( loc, data->index, data->statusName ) ),
602                                new AddressExpr( loc, new NameExpr( loc, pCountName ) )
603                        }
604                )
605        ));
606
607        // Generates: register_select(A, clause1);
608        currBody->push_back( new ExprStmt( loc, genSelectTraitCall( clause, data, "register_select" ) ) );
609
610        // generates: if ( when_cond ) { ... currBody ... }
611        if ( clause->when_cond )
612                body->push_back(
613                        new IfStmt( loc,
614                                new NameExpr( loc, data->whenName ),
615                                currBody
616                        )
617                );
618}
619
620// Used to generate a call to one of the select trait routines
621Expr * GenerateWaitUntilCore::genSelectTraitCall( const WhenClause * clause, const ClauseData * data, string fnName ) {
622        const CodeLocation & loc = clause->location;
623        return new UntypedExpr ( loc,
624                new NameExpr( loc, fnName ),
625                {
626                        new NameExpr( loc, data->targetName ),
627                        new NameExpr( loc, data->nodeName )
628                }
629        );
630}
631
632// Generates:
633/* on_selected( target_1, node_1 ); ... corresponding body of target_1 ...
634*/
635CompoundStmt * GenerateWaitUntilCore::genStmtBlock( const WhenClause * clause, const ClauseData * data ) {
636        const CodeLocation & cLoc = clause->location;
637        return new CompoundStmt( cLoc,
638                {
639                        new IfStmt( cLoc,
640                                genSelectTraitCall( clause, data, "on_selected" ),
641                                ast::deepCopy( clause->stmt )
642                        )
643                }
644        );
645}
646
647// this routine generates and returns the following
648/*for ( int i = 0; i < numClauses; i++ ) {
649        if ( predName(clause_statuses) ) break;
650        if (clause_statuses[i] == __SELECT_SAT) {
651                switch (i) {
652                        case 0:
653                                try {
654                                    on_selected( target1, clause1 );
655                                    dotarget1stmt();
656                                }
657                                finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
658                                break;
659                        ...
660                        case N:
661                                ...
662                                break;
663                }
664        }
665}*/
666CompoundStmt * GenerateWaitUntilCore::genStatusCheckFor( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, string & predName ) {
667        CompoundStmt * ifBody = new CompoundStmt( stmt->location );
668        const CodeLocation & loc = stmt->location;
669
670        string switchLabel = namer_label.newName();
671
672        /* generates:
673        switch (i) {
674                case 0:
675                        try {
676                                on_selected( target1, clause1 );
677                                dotarget1stmt();
678                        }
679                        finally { clause_statuses[i] = __SELECT_RUN; unregister_select(target1, clause1); }
680                        break;
681                        ...
682                case N:
683                        ...
684                        break;
685        }*/
686        std::vector<ptr<CaseClause>> switchCases;
687        int idx = 0;
688        for ( const auto & clause: stmt->clauses ) {
689                const CodeLocation & cLoc = clause->location;
690                switchCases.push_back(
691                        new CaseClause( cLoc,
692                                ConstantExpr::from_int( cLoc, idx ),
693                                {
694                                    new CompoundStmt( cLoc,
695                                        {
696                                            new ast::TryStmt( cLoc,
697                                                genStmtBlock( clause, clauseData.at(idx) ),
698                                                {},
699                                                new ast::FinallyClause( cLoc,
700                                                    new CompoundStmt( cLoc,
701                                                        {
702                                                            new ExprStmt( loc,
703                                                                new UntypedExpr ( loc,
704                                                                    new NameExpr( loc, "?=?" ),
705                                                                    {
706                                                                        new UntypedExpr ( loc,
707                                                                            new NameExpr( loc, "?[?]" ),
708                                                                            {
709                                                                                new NameExpr( loc, clauseData.at(0)->statusName ),
710                                                                                new NameExpr( loc, idxName )
711                                                                            }
712                                                                        ),
713                                                                        new NameExpr( loc, "__SELECT_RUN" )
714                                                                    }
715                                                                )
716                                                            ),
717                                                            new ExprStmt( loc, genSelectTraitCall( clause, clauseData.at(idx), "unregister_select" ) )
718                                                        }
719                                                    )
720                                                )
721                                            ),
722                                            new BranchStmt( cLoc, BranchStmt::Kind::Break, Label( cLoc, switchLabel ) )
723                                        }
724                                    )
725                                }
726                        )
727                );
728                idx++;
729        }
730
731        ifBody->push_back(
732                new SwitchStmt( loc,
733                        new NameExpr( loc, idxName ),
734                        std::move( switchCases ),
735                        { Label( loc, switchLabel ) }
736                )
737        );
738
739        // gens:
740        // if (clause_statuses[i] == __SELECT_SAT) {
741        //      ... ifBody  ...
742        // }
743        IfStmt * ifSwitch = new IfStmt( loc,
744                new UntypedExpr ( loc,
745                        new NameExpr( loc, "?==?" ),
746                        {
747                                new UntypedExpr ( loc,
748                                    new NameExpr( loc, "?[?]" ),
749                                    {
750                                        new NameExpr( loc, clauseData.at(0)->statusName ),
751                                        new NameExpr( loc, idxName )
752                                    }
753                                ),
754                                new NameExpr( loc, "__SELECT_SAT" )
755                        }
756                ),      // condition
757                ifBody  // body
758        );
759
760        string forLabel = namer_label.newName();
761
762        // we hoist init here so that this pass can happen after hoistdecls pass
763        return new CompoundStmt( loc,
764                {
765                        new DeclStmt( loc,
766                                new ObjectDecl( loc,
767                                    idxName,
768                                    new BasicType( BasicKind::SignedInt ),
769                                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
770                                )
771                        ),
772                        new ForStmt( loc,
773                                {},  // inits
774                                new UntypedExpr ( loc,
775                                    new NameExpr( loc, "?<?" ),
776                                    {
777                                        new NameExpr( loc, idxName ),
778                                        ConstantExpr::from_int( loc, stmt->clauses.size() )
779                                    }
780                                ),  // cond
781                                new UntypedExpr ( loc,
782                                    new NameExpr( loc, "?++" ),
783                                    { new NameExpr( loc, idxName ) }
784                                ),  // inc
785                                new CompoundStmt( loc,
786                                    {
787                                        new IfStmt( loc,
788                                            new UntypedExpr ( loc,
789                                                new NameExpr( loc, predName ),
790                                                { new NameExpr( loc, clauseData.at(0)->statusName ) }
791                                            ),
792                                            new BranchStmt( loc, BranchStmt::Kind::Break, Label( loc, forLabel ) )
793                                        ),
794                                        ifSwitch
795                                    }
796                                ),   // body
797                                { Label( loc, forLabel ) }
798                        )
799                }
800        );
801}
802
803// Generates: !is_full_sat_n() / !is_run_sat_n()
804Expr * genNotSatExpr( const WaitUntilStmt * stmt, string & satName, string & arrName ) {
805        const CodeLocation & loc = stmt->location;
806        return new UntypedExpr ( loc,
807                new NameExpr( loc, "!?" ),
808                {
809                        new UntypedExpr ( loc,
810                                new NameExpr( loc, satName ),
811                                { new NameExpr( loc, arrName ) }
812                        )
813                }
814        );
815}
816
817// Generates the code needed for waituntils with an else ( ... )
818// Checks clauses once after registering for completion and runs them if completes
819// If not enough have run to satisfy predicate after one pass then the else is run
820Stmt * GenerateWaitUntilCore::genElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, vector<ClauseData *> & clauseData ) {
821        return new CompoundStmt( stmt->else_stmt->location,
822                {
823                        genStatusCheckFor( stmt, clauseData, runName ),
824                        new IfStmt( stmt->else_stmt->location,
825                                genNotSatExpr( stmt, runName, arrName ),
826                                ast::deepCopy( stmt->else_stmt )
827                        )
828                }
829        );
830}
831
832Stmt * GenerateWaitUntilCore::genNoElseClauseBranch( const WaitUntilStmt * stmt, string & runName, string & arrName, string & pCountName, vector<ClauseData *> & clauseData ) {
833        CompoundStmt * whileBody = new CompoundStmt( stmt->location );
834        const CodeLocation & loc = stmt->location;
835
836        // generates: __CFA_maybe_park( &park_counter );
837        whileBody->push_back(
838                new ExprStmt( loc,
839                        new UntypedExpr ( loc,
840                                new NameExpr( loc, "__CFA_maybe_park" ),
841                                { new AddressExpr( loc, new NameExpr( loc, pCountName ) ) }
842                        )
843                )
844        );
845
846        whileBody->push_back( genStatusCheckFor( stmt, clauseData, runName ) );
847
848        return new CompoundStmt( loc,
849                {
850                        new WhileDoStmt( loc,
851                                genNotSatExpr( stmt, runName, arrName ),
852                                whileBody,  // body
853                                {}          // no inits
854                        )
855                }
856        );
857}
858
859// generates the following decls for each clause to ensure the target expr and when_cond is only evaluated once
860// typeof(target) & __clause_target_0 = target;
861// bool __when_cond_0 = when_cond; // only generated if when_cond defined
862// select_node clause1;
863void GenerateWaitUntilCore::genClauseInits( const WaitUntilStmt * stmt, vector<ClauseData *> & clauseData, CompoundStmt * body, string & statusName, string & elseWhenName ) {
864        ClauseData * currClause;
865        for ( vector<ClauseData*>::size_type i = 0; i < stmt->clauses.size(); i++ ) {
866                currClause = new ClauseData( i, statusName );
867                currClause->nodeName = namer_node.newName();
868                currClause->targetName = namer_target.newName();
869                currClause->whenName = namer_when.newName();
870                clauseData.push_back(currClause);
871                const CodeLocation & cLoc = stmt->clauses.at(i)->location;
872
873                // typeof(target) & __clause_target_0 = target;
874                body->push_back(
875                        new DeclStmt( cLoc,
876                                new ObjectDecl( cLoc,
877                                    currClause->targetName,
878                                    new ReferenceType(
879                                        new TypeofType( new UntypedExpr( cLoc,
880                                            new NameExpr( cLoc, "__CFA_select_get_type" ),
881                                            { ast::deepCopy( stmt->clauses.at(i)->target ) }
882                                        ))
883                                    ),
884                                    new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->target ) )
885                                )
886                        )
887                );
888
889                // bool __when_cond_0 = when_cond; // only generated if when_cond defined
890                if ( stmt->clauses.at(i)->when_cond )
891                        body->push_back(
892                                new DeclStmt( cLoc,
893                                    new ObjectDecl( cLoc,
894                                        currClause->whenName,
895                                        new BasicType( BasicKind::Bool ),
896                                        new SingleInit( cLoc, ast::deepCopy( stmt->clauses.at(i)->when_cond ) )
897                                    )
898                                )
899                        );
900
901                // select_node clause1;
902                body->push_back(
903                        new DeclStmt( cLoc,
904                                new ObjectDecl( cLoc,
905                                    currClause->nodeName,
906                                    new StructInstType( selectNodeDecl )
907                                )
908                        )
909                );
910        }
911
912        if ( stmt->else_stmt && stmt->else_cond ) {
913                body->push_back(
914                        new DeclStmt( stmt->else_cond->location,
915                                new ObjectDecl( stmt->else_cond->location,
916                                    elseWhenName,
917                                    new BasicType( BasicKind::Bool ),
918                                    new SingleInit( stmt->else_cond->location, ast::deepCopy( stmt->else_cond ) )
919                                )
920                        )
921                );
922        }
923}
924
925/*
926if ( clause_status == &clause1 ) ... clause 1 body ...
927...
928elif ( clause_status == &clausen ) ... clause n body ...
929*/
930Stmt * GenerateWaitUntilCore::buildOrCaseSwitch( const WaitUntilStmt * stmt, string & statusName, vector<ClauseData *> & data ) {
931        const CodeLocation & loc = stmt->location;
932
933        IfStmt * outerIf = nullptr;
934        IfStmt * lastIf = nullptr;
935
936        //adds an if/elif clause for each select clause address to run the corresponding clause stmt
937        for ( long unsigned int i = 0; i < data.size(); i++ ) {
938                const CodeLocation & cLoc = stmt->clauses.at(i)->location;
939
940                IfStmt * currIf = new IfStmt( cLoc,
941                        new UntypedExpr( cLoc,
942                                new NameExpr( cLoc, "?==?" ),
943                                {
944                                    new NameExpr( cLoc, statusName ),
945                                    new CastExpr( cLoc,
946                                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(i)->nodeName ) ),
947                                        new BasicType( BasicKind::LongUnsignedInt ), GeneratedFlag::ExplicitCast
948                                    )
949                                }
950                        ),
951                        genStmtBlock( stmt->clauses.at(i), data.at(i) )
952                );
953
954                if ( i == 0 ) {
955                        outerIf = currIf;
956                } else {
957                        // add ifstmt to else of previous stmt
958                        lastIf->else_ = currIf;
959                }
960
961                lastIf = currIf;
962        }
963
964        return new CompoundStmt( loc,
965                {
966                        new ExprStmt( loc, new UntypedExpr( loc, new NameExpr( loc, "park" ) ) ),
967                        outerIf
968                }
969        );
970}
971
972Stmt * GenerateWaitUntilCore::recursiveOrIfGen( const WaitUntilStmt * stmt, vector<ClauseData *> & data, vector<ClauseData*>::size_type idx, string & elseWhenName ) {
973        if ( idx == data.size() ) {   // base case, gen last else
974                const CodeLocation & cLoc = stmt->else_stmt->location;
975                if ( !stmt->else_stmt ) // normal non-else gen
976                        return buildOrCaseSwitch( stmt, data.at(0)->statusName, data );
977
978                Expr * raceFnCall = new UntypedExpr( stmt->location,
979                        new NameExpr( stmt->location, "__select_node_else_race" ),
980                        { new NameExpr( stmt->location, data.at(0)->nodeName ) }
981                );
982
983                if ( stmt->else_stmt && stmt->else_cond ) { // return else conditional on both when and race
984                        return new IfStmt( cLoc,
985                                new LogicalExpr( cLoc,
986                                    new CastExpr( cLoc,
987                                        new NameExpr( cLoc, elseWhenName ),
988                                        new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
989                                    ),
990                                    new CastExpr( cLoc,
991                                        raceFnCall,
992                                        new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
993                                    ),
994                                    LogicalFlag::AndExpr
995                                ),
996                                ast::deepCopy( stmt->else_stmt ),
997                                buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
998                        );
999                }
1000
1001                // return else conditional on race
1002                return new IfStmt( stmt->else_stmt->location,
1003                        raceFnCall,
1004                        ast::deepCopy( stmt->else_stmt ),
1005                        buildOrCaseSwitch( stmt, data.at(0)->statusName, data )
1006                );
1007        }
1008        const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1009
1010        Expr * baseCond = genSelectTraitCall( stmt->clauses.at(idx), data.at(idx), "register_select" );
1011        Expr * ifCond;
1012
1013        // If we have a when_cond make the register call conditional on it
1014        if ( stmt->clauses.at(idx)->when_cond ) {
1015                ifCond = new LogicalExpr( cLoc,
1016                        new CastExpr( cLoc,
1017                                new NameExpr( cLoc, data.at(idx)->whenName ),
1018                                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1019                        ),
1020                        new CastExpr( cLoc,
1021                                baseCond,
1022                                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1023                        ),
1024                        LogicalFlag::AndExpr
1025                );
1026        } else ifCond = baseCond;
1027
1028        return new CompoundStmt( cLoc,
1029                {   // gens: setup_clause( clause1, &status, 0p );
1030                        new ExprStmt( cLoc,
1031                                new UntypedExpr ( cLoc,
1032                                    new NameExpr( cLoc, "setup_clause" ),
1033                                    {
1034                                        new NameExpr( cLoc, data.at(idx)->nodeName ),
1035                                        new AddressExpr( cLoc, new NameExpr( cLoc, data.at(idx)->statusName ) ),
1036                                        ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicKind::SignedInt ) ) )
1037                                    }
1038                                )
1039                        ),
1040                        // gens: if (__when_cond && register_select()) { clause body } else { ... recursiveOrIfGen ... }
1041                        new IfStmt( cLoc,
1042                                ifCond,
1043                                genStmtBlock( stmt->clauses.at(idx), data.at(idx) ),
1044                                recursiveOrIfGen( stmt, data, idx + 1, elseWhenName )
1045                        )
1046                }
1047        );
1048}
1049
1050// This gens the special case of an all OR waituntil:
1051/*
1052int status = 0;
1053
1054typeof(target) & __clause_target_0 = target;
1055bool __when_cond_0 = when_cond; // only generated if when_cond defined
1056select_node clause1;
1057... generate above for rest of clauses ...
1058
1059try {
1060        setup_clause( clause1, &status, 0p );
1061        if ( __when_cond_0 && register_select( 1 ) ) {
1062                ... clause 1 body ...
1063        } else {
1064                ... recursively gen for each of n clauses ...
1065                setup_clause( clausen, &status, 0p );
1066                if ( __when_cond_n-1 && register_select( n ) ) {
1067                        ... clause n body ...
1068                } else {
1069                        if ( else_when ) ... else clause body ...
1070                        else {
1071                                park();
1072
1073                                // after winning the race and before unpark() clause_status is set to be the winning clause index + 1
1074                                if ( clause_status == &clause1) ... clause 1 body ...
1075                                ...
1076                                elif ( clause_status == &clausen ) ... clause n body ...
1077                        }
1078                }
1079        }
1080}
1081finally {
1082        if ( __when_cond_1 && clause1.status != 0p) unregister_select( 1 ); // if registered unregister
1083        ...
1084        if ( __when_cond_n && clausen.status != 0p) unregister_select( n );
1085}
1086*/
1087Stmt * GenerateWaitUntilCore::genAllOr( const WaitUntilStmt * stmt ) {
1088        const CodeLocation & loc = stmt->location;
1089        string statusName = namer_status.newName();
1090        string elseWhenName = namer_when.newName();
1091        int numClauses = stmt->clauses.size();
1092        CompoundStmt * body = new CompoundStmt( stmt->location );
1093
1094        // Generates: unsigned long int status = 0;
1095        body->push_back( new DeclStmt( loc,
1096                new ObjectDecl( loc,
1097                        statusName,
1098                        new BasicType( BasicKind::LongUnsignedInt ),
1099                        new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1100                )
1101        ));
1102
1103        vector<ClauseData *> clauseData;
1104        genClauseInits( stmt, clauseData, body, statusName, elseWhenName );
1105
1106        vector<int> whenIndices; // track which clauses have whens
1107
1108        CompoundStmt * unregisters = new CompoundStmt( loc );
1109        Expr * ifCond;
1110        for ( int i = 0; i < numClauses; i++ ) {
1111                const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1112                // Gens: node.status != 0p
1113                UntypedExpr * statusPtrCheck = new UntypedExpr( cLoc,
1114                        new NameExpr( cLoc, "?!=?" ),
1115                        {
1116                                ConstantExpr::null( cLoc, new PointerType( new BasicType( BasicKind::LongUnsignedInt ) ) ),
1117                                new UntypedExpr( cLoc,
1118                                    new NameExpr( cLoc, "__get_clause_status" ),
1119                                    { new NameExpr( cLoc, clauseData.at(i)->nodeName ) }
1120                                )
1121                        }
1122                );
1123
1124                // If we have a when_cond make the unregister call conditional on it
1125                if ( stmt->clauses.at(i)->when_cond ) {
1126                        whenIndices.push_back(i);
1127                        ifCond = new LogicalExpr( cLoc,
1128                                new CastExpr( cLoc,
1129                                    new NameExpr( cLoc, clauseData.at(i)->whenName ),
1130                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1131                                ),
1132                                new CastExpr( cLoc,
1133                                    statusPtrCheck,
1134                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1135                                ),
1136                                LogicalFlag::AndExpr
1137                        );
1138                } else ifCond = statusPtrCheck;
1139
1140                unregisters->push_back(
1141                        new IfStmt( cLoc,
1142                                ifCond,
1143                                new ExprStmt( cLoc, genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ) )
1144                        )
1145                );
1146        }
1147
1148        if ( whenIndices.empty() || whenIndices.size() != stmt->clauses.size() ) {
1149                body->push_back(
1150                                new ast::TryStmt( loc,
1151                                new CompoundStmt( loc, { recursiveOrIfGen( stmt, clauseData, 0, elseWhenName ) } ),
1152                                {},
1153                                new ast::FinallyClause( loc, unregisters )
1154                        )
1155                );
1156        } else { // If all clauses have whens, we need to skip the waituntil if they are all false
1157                Expr * outerIfCond = new NameExpr( loc, clauseData.at( whenIndices.at(0) )->whenName );
1158                Expr * lastExpr = outerIfCond;
1159
1160                for ( vector<int>::size_type i = 1; i < whenIndices.size(); i++ ) {
1161                        outerIfCond = new LogicalExpr( loc,
1162                                new CastExpr( loc,
1163                                    new NameExpr( loc, clauseData.at( whenIndices.at(i) )->whenName ),
1164                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1165                                ),
1166                                new CastExpr( loc,
1167                                    lastExpr,
1168                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1169                                ),
1170                                LogicalFlag::OrExpr
1171                        );
1172                        lastExpr = outerIfCond;
1173                }
1174
1175                body->push_back(
1176                                new ast::TryStmt( loc,
1177                                new CompoundStmt( loc,
1178                                    {
1179                                        new IfStmt( loc,
1180                                            outerIfCond,
1181                                            recursiveOrIfGen( stmt, clauseData, 0, elseWhenName )
1182                                        )
1183                                    }
1184                                ),
1185                                {},
1186                                new ast::FinallyClause( loc, unregisters )
1187                        )
1188                );
1189        }
1190
1191        for ( ClauseData * datum : clauseData )
1192                delete datum;
1193
1194        return body;
1195}
1196
1197Stmt * GenerateWaitUntilCore::postvisit( const WaitUntilStmt * stmt ) {
1198        if ( !selectNodeDecl )
1199                SemanticError( stmt, "waituntil statement requires #include <waituntil.hfa>" );
1200
1201        // Prep clause tree to figure out how to set initial statuses
1202        // setTreeSizes( stmt->predicateTree );
1203        if ( paintWhenTree( stmt->predicateTree ) ) // if this returns true we can special case since tree is all OR's
1204                return genAllOr( stmt );
1205
1206        CompoundStmt * tryBody = new CompoundStmt( stmt->location );
1207        CompoundStmt * body = new CompoundStmt( stmt->location );
1208        string statusArrName = namer_status.newName();
1209        string pCountName = namer_park.newName();
1210        string satName = namer_sat.newName();
1211        string runName = namer_run.newName();
1212        string elseWhenName = namer_when.newName();
1213        int numClauses = stmt->clauses.size();
1214        addPredicates( stmt, satName, runName );
1215
1216        const CodeLocation & loc = stmt->location;
1217
1218        // Generates: int park_counter = 0;
1219        body->push_back( new DeclStmt( loc,
1220                new ObjectDecl( loc,
1221                        pCountName,
1222                        new BasicType( BasicKind::SignedInt ),
1223                        new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1224                )
1225        ));
1226
1227        // Generates: int clause_statuses[3] = { 0 };
1228        body->push_back( new DeclStmt( loc,
1229                new ObjectDecl( loc,
1230                        statusArrName,
1231                        new ArrayType( new BasicType( BasicKind::LongUnsignedInt ), ConstantExpr::from_int( loc, numClauses ), LengthFlag::FixedLen, DimensionFlag::DynamicDim ),
1232                        new ListInit( loc,
1233                                {
1234                                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
1235                                }
1236                        )
1237                )
1238        ));
1239
1240        vector<ClauseData *> clauseData;
1241        genClauseInits( stmt, clauseData, body, statusArrName, elseWhenName );
1242
1243        vector<pair<int, WaitUntilStmt::ClauseNode *>> ambiguousClauses;       // list of ambiguous clauses
1244        vector<int> andWhenClauses;    // list of clauses that have an AND op as a direct parent and when_cond defined
1245
1246        collectWhens( stmt->predicateTree, ambiguousClauses, andWhenClauses );
1247
1248        // This is only needed for clauses that have AND as a parent and a when_cond defined
1249        // generates: if ( ! when_cond_0 ) clause_statuses_0 = __SELECT_RUN;
1250        for ( int idx : andWhenClauses ) {
1251                const CodeLocation & cLoc = stmt->clauses.at(idx)->location;
1252                body->push_back(
1253                        new IfStmt( cLoc,
1254                                new UntypedExpr ( cLoc,
1255                                    new NameExpr( cLoc, "!?" ),
1256                                    { new NameExpr( cLoc, clauseData.at(idx)->whenName ) }
1257                                ),  // IfStmt cond
1258                                new ExprStmt( cLoc,
1259                                    new UntypedExpr ( cLoc,
1260                                        new NameExpr( cLoc, "?=?" ),
1261                                        {
1262                                            new UntypedExpr ( cLoc,
1263                                                new NameExpr( cLoc, "?[?]" ),
1264                                                {
1265                                                    new NameExpr( cLoc, statusArrName ),
1266                                                    ConstantExpr::from_int( cLoc, idx )
1267                                                }
1268                                            ),
1269                                            new NameExpr( cLoc, "__SELECT_RUN" )
1270                                        }
1271                                    )
1272                                )  // IfStmt then
1273                        )
1274                );
1275        }
1276
1277        // Only need to generate conditional initial state setting for ambiguous when clauses
1278        if ( !ambiguousClauses.empty() ) {
1279                body->push_back( genWhenStateConditions( stmt, clauseData, ambiguousClauses, 0 ) );
1280        }
1281
1282        // generates the following for each clause:
1283        // setup_clause( clause1, &clause_statuses[0], &park_counter );
1284        // register_select(A, clause1);
1285        for ( int i = 0; i < numClauses; i++ ) {
1286                setUpClause( stmt->clauses.at(i), clauseData.at(i), pCountName, tryBody );
1287        }
1288
1289        // generate satisfy logic based on if there is an else clause and if it is conditional
1290        if ( stmt->else_stmt && stmt->else_cond ) { // gen both else/non else branches
1291                tryBody->push_back(
1292                        new IfStmt( stmt->else_cond->location,
1293                                new NameExpr( stmt->else_cond->location, elseWhenName ),
1294                                genElseClauseBranch( stmt, runName, statusArrName, clauseData ),
1295                                genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData )
1296                        )
1297                );
1298        } else if ( !stmt->else_stmt ) { // normal gen
1299                tryBody->push_back( genNoElseClauseBranch( stmt, runName, statusArrName, pCountName, clauseData ) );
1300        } else { // generate just else
1301                tryBody->push_back( genElseClauseBranch( stmt, runName, statusArrName, clauseData ) );
1302        }
1303
1304        // Collection of unregister calls on resources to be put in finally clause
1305        // for each clause:
1306        // if ( !__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1307        // OR if when( ... ) defined on resource
1308        // if ( when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei ) ) { ... clausei stmt ... }
1309        CompoundStmt * unregisters = new CompoundStmt( loc );
1310
1311        Expr * statusExpr; // !__CFA_has_clause_run( clause_statuses[i] )
1312        for ( int i = 0; i < numClauses; i++ ) {
1313                const CodeLocation & cLoc = stmt->clauses.at(i)->location;
1314
1315                // Generates: !__CFA_has_clause_run( clause_statuses[i] )
1316                statusExpr = new UntypedExpr ( cLoc,
1317                        new NameExpr( cLoc, "!?" ),
1318                        {
1319                                new UntypedExpr ( cLoc,
1320                                    new NameExpr( cLoc, "__CFA_has_clause_run" ),
1321                                    {
1322                                        genArrAccessExpr( cLoc, i, statusArrName )
1323                                    }
1324                                )
1325                        }
1326                );
1327
1328                // Generates:
1329                // (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1330                statusExpr = new LogicalExpr( cLoc,
1331                        new CastExpr( cLoc,
1332                                statusExpr,
1333                                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1334                        ),
1335                        new CastExpr( cLoc,
1336                                genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "unregister_select" ),
1337                                new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1338                        ),
1339                        LogicalFlag::AndExpr
1340                );
1341
1342                // if when cond defined generates:
1343                // when_cond_i && (!__CFA_has_clause_run( clause_statuses[i] )) && unregister_select( ... , clausei );
1344                if ( stmt->clauses.at(i)->when_cond )
1345                        statusExpr = new LogicalExpr( cLoc,
1346                                new CastExpr( cLoc,
1347                                    new NameExpr( cLoc, clauseData.at(i)->whenName ),
1348                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1349                                ),
1350                                new CastExpr( cLoc,
1351                                    statusExpr,
1352                                    new BasicType( BasicKind::Bool ), GeneratedFlag::ExplicitCast
1353                                ),
1354                                LogicalFlag::AndExpr
1355                        );
1356
1357                // generates:
1358                // if ( statusExpr ) { ... clausei stmt ... }
1359                unregisters->push_back(
1360                        new IfStmt( cLoc,
1361                                statusExpr,
1362                                new CompoundStmt( cLoc,
1363                                    {
1364                                        new IfStmt( cLoc,
1365                                            genSelectTraitCall( stmt->clauses.at(i), clauseData.at(i), "on_selected" ),
1366                                            ast::deepCopy( stmt->clauses.at(i)->stmt )
1367                                        )
1368                                    }
1369                                )
1370                        )
1371                );
1372
1373                // // generates:
1374                // // if ( statusExpr ) { ... clausei stmt ... }
1375                // unregisters->push_back(
1376                //     new IfStmt( cLoc,
1377                //         statusExpr,
1378                //         genStmtBlock( stmt->clauses.at(i), clauseData.at(i) )
1379                //     )
1380                // );
1381        }
1382
1383        body->push_back(
1384                new ast::TryStmt(
1385                        loc,
1386                        tryBody,
1387                        {},
1388                        new ast::FinallyClause( loc, unregisters )
1389                )
1390        );
1391
1392        for ( ClauseData * datum : clauseData )
1393                delete datum;
1394
1395        return body;
1396}
1397
1398// To add the predicates at global scope we need to do it in a second pass
1399// Predicates are added after "struct select_node { ... };"
1400class AddPredicateDecls final : public WithDeclsToAdd {
1401        vector<FunctionDecl *> & satFns;
1402        const StructDecl * selectNodeDecl = nullptr;
1403
1404  public:
1405        void previsit( const StructDecl * decl ) {
1406                if ( !decl->body ) {
1407                        return;
1408                } else if ( "select_node" == decl->name ) {
1409                        assert( !selectNodeDecl );
1410                        selectNodeDecl = decl;
1411                        for ( FunctionDecl * fn : satFns )
1412                                declsToAddAfter.push_back(fn);
1413                }
1414        }
1415        AddPredicateDecls( vector<FunctionDecl *> & satFns ): satFns(satFns) {}
1416};
1417
1418void generateWaitUntil( TranslationUnit & translationUnit ) {
1419        vector<FunctionDecl *> satFns;
1420        Pass<GenerateWaitUntilCore>::run( translationUnit, satFns );
1421        Pass<AddPredicateDecls>::run( translationUnit, satFns );
1422}
1423
1424} // namespace Concurrency
1425
1426// Local Variables: //
1427// tab-width: 4 //
1428// mode: c++ //
1429// compile-command: "make install" //
1430// End: //
Note: See TracBrowser for help on using the repository browser.