source: src/Concurrency/Waitfor.cc @ 72f85de

ADTarm-ehast-experimentalcleanup-dtorsenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprpthread-emulationqualifiedEnum
Last change on this file since 72f85de was 08da53d, checked in by Rob Schluntz <rschlunt@…>, 7 years ago

Refactor findSingleExpr and remove unnecessary resolver-generated casts

  • Property mode set to 100644
File size: 15.6 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Waitfor.cc --
8//
9// Author           : Thierry Delisle
10// Created On       : Mon Aug 28 11:06:52 2017
11// Last Modified By :
12// Last Modified On :
13// Update Count     : 5
14//
15
16#include "Concurrency/Keywords.h"
17
18#include <cassert>                 // for assert
19#include <string>                  // for string, operator==
20
21using namespace std::string_literals;
22
23#include "Common/PassVisitor.h"    // for PassVisitor
24#include "Common/SemanticError.h"  // for SemanticError
25#include "Common/utility.h"        // for deleteAll, map_range
26#include "CodeGen/OperatorTable.h" // for isConstructor
27#include "InitTweak/InitTweak.h"   // for getPointerBase
28#include "Parser/LinkageSpec.h"    // for Cforall
29#include "ResolvExpr/Resolver.h"   // for findVoidExpression
30#include "SynTree/Constant.h"      // for Constant
31#include "SynTree/Declaration.h"   // for StructDecl, FunctionDecl, ObjectDecl
32#include "SynTree/Expression.h"    // for VariableExpr, ConstantExpr, Untype...
33#include "SynTree/Initializer.h"   // for SingleInit, ListInit, Initializer ...
34#include "SynTree/Label.h"         // for Label
35#include "SynTree/Statement.h"     // for CompoundStmt, DeclStmt, ExprStmt
36#include "SynTree/Type.h"          // for StructInstType, Type, PointerType
37#include "SynTree/Visitor.h"       // for Visitor, acceptAll
38
39class Attribute;
40/*
41void foo() {
42        while( true ) {
43                when( a < 1 ) waitfor( f, a ) { bar(); }
44                or timeout( swagl() );
45                or waitfor( g, a ) { baz(); }
46                or waitfor( ^?{}, a ) { break; }
47                or waitfor( ^?{} ) { break; }
48        }
49}
50
51void f(int i, float f, A & mutex b, struct foo *  );
52void f(int );
53
54
55                      |  |
56                      |  |
57                            |  |
58                      |  |
59                      |  |
60                    \ |  | /
61                     \    /
62                      \  /
63                       \/
64
65
66void foo() {
67        while( true ) {
68
69                acceptable_t acceptables[3];
70                if( a < 1 ) {
71                        acceptables[0].func = f;
72                        acceptables[0].mon = a;
73                }
74                acceptables[1].func = g;
75                acceptables[1].mon = a;
76
77                acceptables[2].func = f;
78                acceptables[2].mon = a;
79                acceptables[2].is_dtor = true;
80
81                int ret = waitfor_internal( acceptables, swagl() );
82
83                switch( ret ) {
84                        case 0:
85                        {
86                                bar();
87                        }
88                        case 1:
89                        {
90                                baz();
91                        }
92                        case 2:
93                                signal(a);
94                                {
95                                        break;
96                                }
97                }
98        }
99}*/
100
101namespace Concurrency {
102
103        namespace {
104                const std::list<Label> noLabels;
105                const std::list< Attribute * > noAttributes;
106                Type::StorageClasses noStorage;
107                Type::Qualifiers noQualifiers;
108        }
109
110        //=============================================================================================
111        // Pass declarations
112        //=============================================================================================
113
114        class GenerateWaitForPass final : public WithIndexer {
115          public:
116
117                void premutate( FunctionDecl * decl );
118                void premutate( StructDecl   * decl );
119
120                Statement * postmutate( WaitForStmt * stmt );
121
122                static void generate( std::list< Declaration * > & translationUnit ) {
123                        PassVisitor< GenerateWaitForPass > impl;
124                        acceptAll( translationUnit, impl );
125                }
126
127                ObjectDecl * declare( unsigned long count, CompoundStmt * stmt );
128                ObjectDecl * declareFlag( CompoundStmt * stmt );
129                Statement  * makeSetter( ObjectDecl * flag );
130                ObjectDecl * declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt );
131                void         init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, Statement * settter, CompoundStmt * stmt );
132                Expression * init_timeout( Expression *& time, Expression *& time_cond, bool has_else, Expression *& else_cond, Statement * settter, CompoundStmt * stmt );
133                Expression * call(size_t count, ObjectDecl * acceptables, Expression * timeout, CompoundStmt * stmt);
134                void         choose( WaitForStmt * waitfor, Expression  * result, CompoundStmt * stmt );
135
136                static void implement( std::list< Declaration * > & translationUnit ) {
137                        PassVisitor< GenerateWaitForPass > impl;
138                        mutateAll( translationUnit, impl );
139                }
140
141
142          private:
143                FunctionDecl        * decl_waitfor    = nullptr;
144                StructDecl          * decl_mask       = nullptr;
145                StructDecl          * decl_acceptable = nullptr;
146                StructDecl          * decl_monitor    = nullptr;
147
148                static std::unique_ptr< Type > generic_func;
149
150                UniqueName namer_acc = "__acceptables_"s;
151                UniqueName namer_idx = "__index_"s;
152                UniqueName namer_flg = "__do_run_"s;
153                UniqueName namer_msk = "__mask_"s;
154                UniqueName namer_mon = "__monitors_"s;
155                UniqueName namer_tim = "__timeout_"s;
156        };
157
158        //=============================================================================================
159        // General entry routine
160        //=============================================================================================
161        void generateWaitFor( std::list< Declaration * > & translationUnit ) {
162                GenerateWaitForPass     ::implement( translationUnit );
163        }
164
165        //=============================================================================================
166        // Generic helper routine
167        //=============================================================================================
168
169        namespace {
170                Expression * makeOpIndex( DeclarationWithType * array, unsigned long index ) {
171                        return new UntypedExpr(
172                                new NameExpr( "?[?]" ),
173                                {
174                                        new VariableExpr( array ),
175                                        new ConstantExpr( Constant::from_ulong( index ) )
176                                }
177                        );
178                }
179
180                Expression * makeOpAssign( Expression * lhs, Expression * rhs ) {
181                        return new UntypedExpr(
182                                        new NameExpr( "?=?" ),
183                                        { lhs, rhs }
184                        );
185                }
186
187                Expression * makeOpMember( Expression * sue, const std::string & mem ) {
188                        return new UntypedMemberExpr( new NameExpr( mem ), sue );
189                }
190
191                Statement * makeAccStatement( DeclarationWithType * object, unsigned long index, const std::string & member, Expression * value, const SymTab::Indexer & indexer ) {
192                        Expression * expr = makeOpAssign(
193                                makeOpMember(
194                                        makeOpIndex(
195                                                object,
196                                                index
197                                        ),
198                                        member
199                                ),
200                                value
201                        );
202
203                        ResolvExpr::findVoidExpression( expr, indexer );
204
205                        return new ExprStmt( noLabels, expr );
206                }
207
208                Expression * safeCond( Expression * expr, bool ifnull = true ) {
209                        if( expr ) return expr;
210
211                        return new ConstantExpr( Constant::from_bool( ifnull ) );
212                }
213
214                VariableExpr * extractVariable( Expression * func ) {
215                        if( VariableExpr * var = dynamic_cast< VariableExpr * >( func ) ) {
216                                return var;
217                        }
218
219                        CastExpr * cast = strict_dynamic_cast< CastExpr * >( func );
220                        return strict_dynamic_cast< VariableExpr * >( cast->arg );
221                }
222
223                Expression * detectIsDtor( Expression * func ) {
224                        VariableExpr * typed_func = extractVariable( func );
225                        bool is_dtor = InitTweak::isDestructor( typed_func->var );
226                        return new ConstantExpr( Constant::from_bool( is_dtor ) );
227                }
228        };
229
230
231        //=============================================================================================
232        // Generate waitfor implementation
233        //=============================================================================================
234
235        void GenerateWaitForPass::premutate( FunctionDecl * decl) {
236                if( decl->name != "__waitfor_internal" ) return;
237
238                decl_waitfor = decl;
239        }
240
241        void GenerateWaitForPass::premutate( StructDecl   * decl ) {
242                if( ! decl->body ) return;
243
244                if( decl->name == "__acceptable_t" ) {
245                        assert( !decl_acceptable );
246                        decl_acceptable = decl;
247                }
248                else if( decl->name == "__waitfor_mask_t" ) {
249                        assert( !decl_mask );
250                        decl_mask = decl;
251                }
252                else if( decl->name == "monitor_desc" ) {
253                        assert( !decl_monitor );
254                        decl_monitor = decl;
255                }
256        }
257
258        Statement * GenerateWaitForPass::postmutate( WaitForStmt * waitfor ) {
259                if( !decl_monitor || !decl_acceptable || !decl_mask ) throw SemanticError( "waitfor keyword requires monitors to be in scope, add #include <monitor>", waitfor );
260
261                CompoundStmt * stmt = new CompoundStmt( noLabels );
262
263                ObjectDecl * acceptables = declare( waitfor->clauses.size(), stmt );
264                ObjectDecl * flag        = declareFlag( stmt );
265                Statement  * setter      = makeSetter( flag );
266
267                int index = 0;
268                for( auto & clause : waitfor->clauses ) {
269                        init( acceptables, index, clause, setter, stmt );
270
271                        index++;
272                }
273
274                Expression * timeout = init_timeout(
275                        waitfor->timeout.time,
276                        waitfor->timeout.condition,
277                        waitfor->orelse .statement,
278                        waitfor->orelse .condition,
279                        setter,
280                        stmt
281                );
282
283                CompoundStmt * compound = new CompoundStmt( noLabels );
284                stmt->push_back( new IfStmt(
285                        noLabels,
286                        safeCond( new VariableExpr( flag ) ),
287                        compound,
288                        nullptr
289                ));
290
291                Expression * result = call( waitfor->clauses.size(), acceptables, timeout, compound );
292
293                choose( waitfor, result, compound );
294
295                return stmt;
296        }
297
298        ObjectDecl * GenerateWaitForPass::declare( unsigned long count, CompoundStmt * stmt )
299        {
300                ObjectDecl * acceptables = ObjectDecl::newObject(
301                        namer_acc.newName(),
302                        new ArrayType(
303                                noQualifiers,
304                                new StructInstType(
305                                        noQualifiers,
306                                        decl_acceptable
307                                ),
308                                new ConstantExpr( Constant::from_ulong( count ) ),
309                                false,
310                                false
311                        ),
312                        nullptr
313                );
314
315                stmt->push_back( new DeclStmt( noLabels, acceptables) );
316
317                Expression * set = new UntypedExpr(
318                        new NameExpr( "__builtin_memset" ),
319                        {
320                                new VariableExpr( acceptables ),
321                                new ConstantExpr( Constant::from_int( 0 ) ),
322                                new SizeofExpr( new VariableExpr( acceptables ) )
323                        }
324                );
325
326                ResolvExpr::findVoidExpression( set, indexer );
327
328                stmt->push_back( new ExprStmt( noLabels, set ) );
329
330                return acceptables;
331        }
332
333        ObjectDecl * GenerateWaitForPass::declareFlag( CompoundStmt * stmt ) {
334                ObjectDecl * flag = ObjectDecl::newObject(
335                        namer_flg.newName(),
336                        new BasicType(
337                                noQualifiers,
338                                BasicType::Bool
339                        ),
340                        new SingleInit( new ConstantExpr( Constant::from_ulong( 0 ) ) )
341                );
342
343                stmt->push_back( new DeclStmt( noLabels, flag) );
344
345                return flag;
346        }
347
348        Statement * GenerateWaitForPass::makeSetter( ObjectDecl * flag ) {
349                Expression * expr = new UntypedExpr(
350                        new NameExpr( "?=?" ),
351                        {
352                                new VariableExpr( flag ),
353                                new ConstantExpr( Constant::from_ulong( 1 ) )
354                        }
355                );
356
357                ResolvExpr::findVoidExpression( expr, indexer );
358
359                return new ExprStmt( noLabels, expr );
360        }
361
362        ObjectDecl * GenerateWaitForPass::declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt ) {
363
364                ObjectDecl * mon = ObjectDecl::newObject(
365                        namer_mon.newName(),
366                        new ArrayType(
367                                noQualifiers,
368                                new PointerType(
369                                        noQualifiers,
370                                        new StructInstType(
371                                                noQualifiers,
372                                                decl_monitor
373                                        )
374                                ),
375                                new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ),
376                                false,
377                                false
378                        ),
379                        new ListInit(
380                                map_range < std::list<Initializer*> > ( clause.target.arguments, [this](Expression * expr ){
381                                        Expression * init = new CastExpr(
382                                                new UntypedExpr(
383                                                        new NameExpr( "get_monitor" ),
384                                                        { expr }
385                                                ),
386                                                new PointerType(
387                                                        noQualifiers,
388                                                        new StructInstType(
389                                                                noQualifiers,
390                                                                decl_monitor
391                                                        )
392                                                )
393                                        );
394
395                                        ResolvExpr::findSingleExpression( init, indexer );
396                                        return new SingleInit( init );
397                                })
398                        )
399                );
400
401                stmt->push_back( new DeclStmt( noLabels, mon) );
402
403                return mon;
404        }
405
406        void GenerateWaitForPass::init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, Statement * setter, CompoundStmt * stmt ) {
407
408                ObjectDecl * monitors = declMon( clause, stmt );
409
410                Type * fptr_t = new PointerType( noQualifiers, new FunctionType( noQualifiers, true ) );
411
412                stmt->push_back( new IfStmt(
413                        noLabels,
414                        safeCond( clause.condition ),
415                        new CompoundStmt({
416                                makeAccStatement( acceptables, index, "is_dtor", detectIsDtor( clause.target.function )                                    , indexer ),
417                                makeAccStatement( acceptables, index, "func"   , new CastExpr( clause.target.function, fptr_t )                            , indexer ),
418                                makeAccStatement( acceptables, index, "list"   , new VariableExpr( monitors )                                              , indexer ),
419                                makeAccStatement( acceptables, index, "size"   , new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ), indexer ),
420                                setter->clone()
421                        }),
422                        nullptr
423                ));
424
425                clause.target.function = nullptr;
426                clause.target.arguments.empty();
427                clause.condition = nullptr;
428        }
429
430        Expression * GenerateWaitForPass::init_timeout(
431                Expression *& time,
432                Expression *& time_cond,
433                bool has_else,
434                Expression *& else_cond,
435                Statement * setter,
436                CompoundStmt * stmt
437        ) {
438                ObjectDecl * timeout = ObjectDecl::newObject(
439                        namer_tim.newName(),
440                        new BasicType(
441                                noQualifiers,
442                                BasicType::LongLongUnsignedInt
443                        ),
444                        new SingleInit(
445                                new ConstantExpr( Constant::from_int( -1 ) )
446                        )
447                );
448
449                stmt->push_back( new DeclStmt( noLabels, timeout ) );
450
451                if( time ) {
452                        stmt->push_back( new IfStmt(
453                                noLabels,
454                                safeCond( time_cond ),
455                                new CompoundStmt({
456                                        new ExprStmt(
457                                                noLabels,
458                                                makeOpAssign(
459                                                        new VariableExpr( timeout ),
460                                                        time
461                                                )
462                                        ),
463                                        setter->clone()
464                                }),
465                                nullptr
466                        ));
467
468                        time = time_cond = nullptr;
469                }
470
471                if( has_else ) {
472                        stmt->push_back( new IfStmt(
473                                noLabels,
474                                safeCond( else_cond ),
475                                new CompoundStmt({
476                                        new ExprStmt(
477                                                noLabels,
478                                                makeOpAssign(
479                                                        new VariableExpr( timeout ),
480                                                        new ConstantExpr( Constant::from_ulong( 0 ) )
481                                                )
482                                        ),
483                                        setter->clone()
484                                }),
485                                nullptr
486                        ));
487
488                        else_cond = nullptr;
489                }
490
491                delete setter;
492
493                return new VariableExpr( timeout );
494        }
495
496        Expression * GenerateWaitForPass::call(
497                size_t count,
498                ObjectDecl * acceptables,
499                Expression * timeout,
500                CompoundStmt * stmt
501        ) {
502                ObjectDecl * index = ObjectDecl::newObject(
503                        namer_idx.newName(),
504                        new BasicType(
505                                noQualifiers,
506                                BasicType::ShortSignedInt
507                        ),
508                        new SingleInit(
509                                new ConstantExpr( Constant::from_int( -1 ) )
510                        )
511                );
512
513                stmt->push_back( new DeclStmt( noLabels, index ) );
514
515                ObjectDecl * mask = ObjectDecl::newObject(
516                        namer_msk.newName(),
517                        new StructInstType(
518                                noQualifiers,
519                                decl_mask
520                        ),
521                        new ListInit({
522                                new SingleInit( new AddressExpr( new VariableExpr( index ) ) ),
523                                new SingleInit( new VariableExpr( acceptables ) ),
524                                new SingleInit( new ConstantExpr( Constant::from_ulong( count ) ) )
525                        })
526                );
527
528                stmt->push_back( new DeclStmt( noLabels, mask ) );
529
530                stmt->push_back( new ExprStmt(
531                        noLabels,
532                        new ApplicationExpr(
533                                VariableExpr::functionPointer( decl_waitfor ),
534                                {
535                                        new CastExpr(
536                                                new VariableExpr( mask ),
537                                                new ReferenceType(
538                                                        noQualifiers,
539                                                        new StructInstType(
540                                                                noQualifiers,
541                                                                decl_mask
542                                                        )
543                                                )
544                                        ),
545                                        timeout
546                                }
547                        )
548                ));
549
550                return new VariableExpr( index );
551        }
552
553        void GenerateWaitForPass::choose(
554                WaitForStmt * waitfor,
555                Expression  * result,
556                CompoundStmt * stmt
557        ) {
558                SwitchStmt * swtch = new SwitchStmt(
559                        noLabels,
560                        result,
561                        std::list<Statement *>()
562                );
563
564                unsigned long i = 0;
565                for( auto & clause : waitfor->clauses ) {
566                        swtch->statements.push_back(
567                                new CaseStmt(
568                                        noLabels,
569                                        new ConstantExpr( Constant::from_ulong( i++ ) ),
570                                        {
571                                                clause.statement,
572                                                new BranchStmt(
573                                                        noLabels,
574                                                        "",
575                                                        BranchStmt::Break
576                                                )
577                                        }
578                                )
579                        );
580                }
581
582                if(waitfor->timeout.statement) {
583                        swtch->statements.push_back(
584                                new CaseStmt(
585                                        noLabels,
586                                        new ConstantExpr( Constant::from_int( -2 ) ),
587                                        {
588                                                waitfor->timeout.statement,
589                                                new BranchStmt(
590                                                        noLabels,
591                                                        "",
592                                                        BranchStmt::Break
593                                                )
594                                        }
595                                )
596                        );
597                }
598
599                if(waitfor->orelse.statement) {
600                        swtch->statements.push_back(
601                                new CaseStmt(
602                                        noLabels,
603                                        new ConstantExpr( Constant::from_int( -1 ) ),
604                                        {
605                                                waitfor->orelse.statement,
606                                                new BranchStmt(
607                                                        noLabels,
608                                                        "",
609                                                        BranchStmt::Break
610                                                )
611                                        }
612                                )
613                        );
614                }
615
616                stmt->push_back( swtch );
617        }
618};
619
620// Local Variables: //
621// mode: c //
622// tab-width: 4 //
623// End: //
Note: See TracBrowser for help on using the repository browser.