source: src/Concurrency/Waitfor.cc @ d67cdb7

ADTaaron-thesisarm-ehast-experimentalcleanup-dtorsdeferred_resndemanglerenumforall-pointer-decayjacob/cs343-translationjenkins-sandboxnew-astnew-ast-unique-exprnew-envno_listpersistent-indexerpthread-emulationqualifiedEnumresolv-newwith_gc
Last change on this file since d67cdb7 was d67cdb7, checked in by Peter A. Buhr <pabuhr@…>, 7 years ago

merge

  • Property mode set to 100644
File size: 15.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Waitfor.cc --
8//
9// Author           : Thierry Delisle
10// Created On       : Mon Aug 28 11:06:52 2017
11// Last Modified By :
12// Last Modified On :
13// Update Count     : 5
14//
15
16#include "Concurrency/Keywords.h"
17
18#include <cassert>                 // for assert
19#include <string>                  // for string, operator==
20
21using namespace std::string_literals;
22
23#include "Common/PassVisitor.h"    // for PassVisitor
24#include "Common/SemanticError.h"  // for SemanticError
25#include "Common/utility.h"        // for deleteAll, map_range
26#include "CodeGen/OperatorTable.h" // for isConstructor
27#include "InitTweak/InitTweak.h"   // for getPointerBase
28#include "Parser/LinkageSpec.h"    // for Cforall
29#include "ResolvExpr/Resolver.h"   // for findVoidExpression
30#include "SynTree/Constant.h"      // for Constant
31#include "SynTree/Declaration.h"   // for StructDecl, FunctionDecl, ObjectDecl
32#include "SynTree/Expression.h"    // for VariableExpr, ConstantExpr, Untype...
33#include "SynTree/Initializer.h"   // for SingleInit, ListInit, Initializer ...
34#include "SynTree/Label.h"         // for Label
35#include "SynTree/Statement.h"     // for CompoundStmt, DeclStmt, ExprStmt
36#include "SynTree/Type.h"          // for StructInstType, Type, PointerType
37#include "SynTree/Visitor.h"       // for Visitor, acceptAll
38
39class Attribute;
40/*
41void foo() {
42        while( true ) {
43                when( a < 1 ) waitfor( f, a ) { bar(); }
44                or timeout( swagl() );
45                or waitfor( g, a ) { baz(); }
46                or waitfor( ^?{}, a ) { break; }
47                or waitfor( ^?{} ) { break; }
48        }
49}
50
51void f(int i, float f, A & mutex b, struct foo *  );
52void f(int );
53
54
55                      |  |
56                      |  |
57                            |  |
58                      |  |
59                      |  |
60                    \ |  | /
61                     \    /
62                      \  /
63                       \/
64
65
66void foo() {
67        while( true ) {
68
69                acceptable_t acceptables[3];
70                if( a < 1 ) {
71                        acceptables[0].func = f;
72                        acceptables[0].mon = a;
73                }
74                acceptables[1].func = g;
75                acceptables[1].mon = a;
76
77                acceptables[2].func = f;
78                acceptables[2].mon = a;
79                acceptables[2].is_dtor = true;
80
81                int ret = waitfor_internal( acceptables, swagl() );
82
83                switch( ret ) {
84                        case 0:
85                        {
86                                bar();
87                        }
88                        case 1:
89                        {
90                                baz();
91                        }
92                        case 2:
93                                signal(a);
94                                {
95                                        break;
96                                }
97                }
98        }
99}*/
100
101namespace Concurrency {
102
103        namespace {
104                const std::list<Label> noLabels;
105                const std::list< Attribute * > noAttributes;
106                Type::StorageClasses noStorage;
107                Type::Qualifiers noQualifiers;
108        }
109
110        //=============================================================================================
111        // Pass declarations
112        //=============================================================================================
113
114        class GenerateWaitForPass final : public WithIndexer {
115          public:
116
117                void premutate( FunctionDecl * decl );
118                void premutate( StructDecl   * decl );
119
120                Statement * postmutate( WaitForStmt * stmt );
121
122                static void generate( std::list< Declaration * > & translationUnit ) {
123                        PassVisitor< GenerateWaitForPass > impl;
124                        acceptAll( translationUnit, impl );
125                }
126
127                ObjectDecl * declare( unsigned long count, CompoundStmt * stmt );
128                ObjectDecl * declareFlag( CompoundStmt * stmt );
129                Statement  * makeSetter( ObjectDecl * flag );
130                ObjectDecl * declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt );
131                void         init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, Statement * settter, CompoundStmt * stmt );
132                Expression * init_timeout( Expression *& time, Expression *& time_cond, bool has_else, Expression *& else_cond, Statement * settter, CompoundStmt * stmt );
133                Expression * call(size_t count, ObjectDecl * acceptables, Expression * timeout, CompoundStmt * stmt);
134                void         choose( WaitForStmt * waitfor, Expression  * result, CompoundStmt * stmt );
135
136                static void implement( std::list< Declaration * > & translationUnit ) {
137                        PassVisitor< GenerateWaitForPass > impl;
138                        mutateAll( translationUnit, impl );
139                }
140
141
142          private:
143                FunctionDecl        * decl_waitfor    = nullptr;
144                StructDecl          * decl_mask       = nullptr;
145                StructDecl          * decl_acceptable = nullptr;
146                StructDecl          * decl_monitor    = nullptr;
147
148                static std::unique_ptr< Type > generic_func;
149
150                UniqueName namer_acc = "__acceptables_"s;
151                UniqueName namer_idx = "__index_"s;
152                UniqueName namer_flg = "__do_run_"s;
153                UniqueName namer_msk = "__mask_"s;
154                UniqueName namer_mon = "__monitors_"s;
155                UniqueName namer_tim = "__timeout_"s;
156        };
157
158        //=============================================================================================
159        // General entry routine
160        //=============================================================================================
161        void generateWaitFor( std::list< Declaration * > & translationUnit ) {
162                GenerateWaitForPass     ::implement( translationUnit );
163        }
164
165        //=============================================================================================
166        // Generic helper routine
167        //=============================================================================================
168
169        namespace {
170                Expression * makeOpIndex( DeclarationWithType * array, unsigned long index ) {
171                        return new UntypedExpr(
172                                new NameExpr( "?[?]" ),
173                                {
174                                        new VariableExpr( array ),
175                                        new ConstantExpr( Constant::from_ulong( index ) )
176                                }
177                        );
178                }
179
180                Expression * makeOpAssign( Expression * lhs, Expression * rhs ) {
181                        return new UntypedExpr(
182                                        new NameExpr( "?=?" ),
183                                        { lhs, rhs }
184                        );
185                }
186
187                Expression * makeOpMember( Expression * sue, const std::string & mem ) {
188                        return new UntypedMemberExpr( new NameExpr( mem ), sue );
189                }
190
191                Statement * makeAccStatement( DeclarationWithType * object, unsigned long index, const std::string & member, Expression * value, const SymTab::Indexer & indexer ) {
192                        std::unique_ptr< Expression > expr( makeOpAssign(
193                                makeOpMember(
194                                        makeOpIndex(
195                                                object,
196                                                index
197                                        ),
198                                        member
199                                ),
200                                value
201                        ) );
202
203                        return new ExprStmt( noLabels, ResolvExpr::findVoidExpression( expr.get(), indexer ) );
204                }
205
206                Expression * safeCond( Expression * expr, bool ifnull = true ) {
207                        if( expr ) return expr;
208
209                        return new ConstantExpr( Constant::from_bool( ifnull ) );
210                }
211
212                VariableExpr * extractVariable( Expression * func ) {
213                        if( VariableExpr * var = dynamic_cast< VariableExpr * >( func ) ) {
214                                return var;
215                        }
216
217                        CastExpr * cast = strict_dynamic_cast< CastExpr * >( func );
218                        return strict_dynamic_cast< VariableExpr * >( cast->arg );
219                }
220
221                Expression * detectIsDtor( Expression * func ) {
222                        VariableExpr * typed_func = extractVariable( func );
223                        bool is_dtor = InitTweak::isDestructor( typed_func->var );
224                        return new ConstantExpr( Constant::from_bool( is_dtor ) );
225                }
226        };
227
228
229        //=============================================================================================
230        // Generate waitfor implementation
231        //=============================================================================================
232
233        void GenerateWaitForPass::premutate( FunctionDecl * decl) {
234                if( decl->name != "__waitfor_internal" ) return;
235
236                decl_waitfor = decl;
237        }
238
239        void GenerateWaitForPass::premutate( StructDecl   * decl ) {
240                if( ! decl->body ) return;
241
242                if( decl->name == "__acceptable_t" ) {
243                        assert( !decl_acceptable );
244                        decl_acceptable = decl;
245                }
246                else if( decl->name == "__waitfor_mask_t" ) {
247                        assert( !decl_mask );
248                        decl_mask = decl;
249                }
250                else if( decl->name == "monitor_desc" ) {
251                        assert( !decl_monitor );
252                        decl_monitor = decl;
253                }
254        }
255
256        Statement * GenerateWaitForPass::postmutate( WaitForStmt * waitfor ) {
257                if( !decl_monitor || !decl_acceptable || !decl_mask ) throw SemanticError( "waitfor keyword requires monitors to be in scope, add #include <monitor>", waitfor );
258
259                CompoundStmt * stmt = new CompoundStmt( noLabels );
260
261                ObjectDecl * acceptables = declare( waitfor->clauses.size(), stmt );
262                ObjectDecl * flag        = declareFlag( stmt );
263                Statement  * setter      = makeSetter( flag );
264
265                int index = 0;
266                for( auto & clause : waitfor->clauses ) {
267                        init( acceptables, index, clause, setter, stmt );
268
269                        index++;
270                }
271
272                Expression * timeout = init_timeout(
273                        waitfor->timeout.time,
274                        waitfor->timeout.condition,
275                        waitfor->orelse .statement,
276                        waitfor->orelse .condition,
277                        setter,
278                        stmt
279                );
280
281                CompoundStmt * compound = new CompoundStmt( noLabels );
282                stmt->push_back( new IfStmt(
283                        noLabels,
284                        safeCond( new VariableExpr( flag ) ),
285                        compound,
286                        nullptr
287                ));
288
289                Expression * result = call( waitfor->clauses.size(), acceptables, timeout, compound );
290
291                choose( waitfor, result, compound );
292
293                return stmt;
294        }
295
296        ObjectDecl * GenerateWaitForPass::declare( unsigned long count, CompoundStmt * stmt )
297        {
298                ObjectDecl * acceptables = ObjectDecl::newObject(
299                        namer_acc.newName(),
300                        new ArrayType(
301                                noQualifiers,
302                                new StructInstType(
303                                        noQualifiers,
304                                        decl_acceptable
305                                ),
306                                new ConstantExpr( Constant::from_ulong( count ) ),
307                                false,
308                                false
309                        ),
310                        nullptr
311                );
312
313                stmt->push_back( new DeclStmt( noLabels, acceptables) );
314
315                UntypedExpr * set = new UntypedExpr(
316                        new NameExpr( "__builtin_memset" ),
317                        {
318                                new VariableExpr( acceptables ),
319                                new ConstantExpr( Constant::from_int( 0 ) ),
320                                new SizeofExpr( new VariableExpr( acceptables ) )
321                        }
322                );
323
324                Expression * resolved_set = ResolvExpr::findVoidExpression( set, indexer );
325                delete set;
326
327                stmt->push_back( new ExprStmt( noLabels, resolved_set ) );
328
329                return acceptables;
330        }
331
332        ObjectDecl * GenerateWaitForPass::declareFlag( CompoundStmt * stmt ) {
333                ObjectDecl * flag = ObjectDecl::newObject(
334                        namer_flg.newName(),
335                        new BasicType(
336                                noQualifiers,
337                                BasicType::Bool
338                        ),
339                        new SingleInit( new ConstantExpr( Constant::from_ulong( 0 ) ) )
340                );
341
342                stmt->push_back( new DeclStmt( noLabels, flag) );
343
344                return flag;
345        }
346
347        Statement * GenerateWaitForPass::makeSetter( ObjectDecl * flag ) {
348                Expression * untyped = new UntypedExpr(
349                        new NameExpr( "?=?" ),
350                        {
351                                new VariableExpr( flag ),
352                                new ConstantExpr( Constant::from_ulong( 1 ) )
353                        }
354                );
355
356                Expression * expr = ResolvExpr::findVoidExpression( untyped, indexer );
357                delete untyped;
358
359                return new ExprStmt( noLabels, expr );
360        }
361
362        ObjectDecl * GenerateWaitForPass::declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt ) {
363
364                ObjectDecl * mon = ObjectDecl::newObject(
365                        namer_mon.newName(),
366                        new ArrayType(
367                                noQualifiers,
368                                new PointerType(
369                                        noQualifiers,
370                                        new StructInstType(
371                                                noQualifiers,
372                                                decl_monitor
373                                        )
374                                ),
375                                new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ),
376                                false,
377                                false
378                        ),
379                        new ListInit(
380                                map_range < std::list<Initializer*> > ( clause.target.arguments, [this](Expression * expr ){
381                                        Expression * untyped = new CastExpr(
382                                                new UntypedExpr(
383                                                        new NameExpr( "get_monitor" ),
384                                                        { expr }
385                                                ),
386                                                new PointerType(
387                                                        noQualifiers,
388                                                        new StructInstType(
389                                                                noQualifiers,
390                                                                decl_monitor
391                                                        )
392                                                )
393                                        );
394
395                                        Expression * init = ResolvExpr::findSingleExpression( untyped, indexer );
396                                        delete untyped;
397                                        return new SingleInit( init );
398                                })
399                        )
400                );
401
402                stmt->push_back( new DeclStmt( noLabels, mon) );
403
404                return mon;
405        }
406
407        void GenerateWaitForPass::init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, Statement * setter, CompoundStmt * stmt ) {
408
409                ObjectDecl * monitors = declMon( clause, stmt );
410
411                Type * fptr_t = new PointerType( noQualifiers, new FunctionType( noQualifiers, true ) );
412
413                stmt->push_back( new IfStmt(
414                        noLabels,
415                        safeCond( clause.condition ),
416                        new CompoundStmt({
417                                makeAccStatement( acceptables, index, "is_dtor", detectIsDtor( clause.target.function )                                    , indexer ),
418                                makeAccStatement( acceptables, index, "func"   , new CastExpr( clause.target.function, fptr_t )                            , indexer ),
419                                makeAccStatement( acceptables, index, "list"   , new VariableExpr( monitors )                                              , indexer ),
420                                makeAccStatement( acceptables, index, "size"   , new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ), indexer ),
421                                setter->clone()
422                        }),
423                        nullptr
424                ));
425
426                clause.target.function = nullptr;
427                clause.target.arguments.empty();
428                clause.condition = nullptr;
429        }
430
431        Expression * GenerateWaitForPass::init_timeout(
432                Expression *& time,
433                Expression *& time_cond,
434                bool has_else,
435                Expression *& else_cond,
436                Statement * setter,
437                CompoundStmt * stmt
438        ) {
439                ObjectDecl * timeout = ObjectDecl::newObject(
440                        namer_tim.newName(),
441                        new BasicType(
442                                noQualifiers,
443                                BasicType::LongLongUnsignedInt
444                        ),
445                        new SingleInit(
446                                new ConstantExpr( Constant::from_int( -1 ) )
447                        )
448                );
449
450                stmt->push_back( new DeclStmt( noLabels, timeout ) );
451
452                if( time ) {
453                        stmt->push_back( new IfStmt(
454                                noLabels,
455                                safeCond( time_cond ),
456                                new CompoundStmt({
457                                        new ExprStmt(
458                                                noLabels,
459                                                makeOpAssign(
460                                                        new VariableExpr( timeout ),
461                                                        time
462                                                )
463                                        ),
464                                        setter->clone()
465                                }),
466                                nullptr
467                        ));
468
469                        time = time_cond = nullptr;
470                }
471
472                if( has_else ) {
473                        stmt->push_back( new IfStmt(
474                                noLabels,
475                                safeCond( else_cond ),
476                                new CompoundStmt({
477                                        new ExprStmt(
478                                                noLabels,
479                                                makeOpAssign(
480                                                        new VariableExpr( timeout ),
481                                                        new ConstantExpr( Constant::from_ulong( 0 ) )
482                                                )
483                                        ),
484                                        setter->clone()
485                                }),
486                                nullptr
487                        ));
488
489                        else_cond = nullptr;
490                }
491
492                delete setter;
493
494                return new VariableExpr( timeout );
495        }
496
497        Expression * GenerateWaitForPass::call(
498                size_t count,
499                ObjectDecl * acceptables,
500                Expression * timeout,
501                CompoundStmt * stmt
502        ) {
503                ObjectDecl * index = ObjectDecl::newObject(
504                        namer_idx.newName(),
505                        new BasicType(
506                                noQualifiers,
507                                BasicType::ShortSignedInt
508                        ),
509                        new SingleInit(
510                                new ConstantExpr( Constant::from_int( -1 ) )
511                        )
512                );
513
514                stmt->push_back( new DeclStmt( noLabels, index ) );
515
516                ObjectDecl * mask = ObjectDecl::newObject(
517                        namer_msk.newName(),
518                        new StructInstType(
519                                noQualifiers,
520                                decl_mask
521                        ),
522                        new ListInit({
523                                new SingleInit( new AddressExpr( new VariableExpr( index ) ) ),
524                                new SingleInit( new VariableExpr( acceptables ) ),
525                                new SingleInit( new ConstantExpr( Constant::from_ulong( count ) ) )
526                        })
527                );
528
529                stmt->push_back( new DeclStmt( noLabels, mask ) );
530
531                stmt->push_back( new ExprStmt(
532                        noLabels,
533                        new ApplicationExpr(
534                                VariableExpr::functionPointer( decl_waitfor ),
535                                {
536                                        new CastExpr(
537                                                new VariableExpr( mask ),
538                                                new ReferenceType(
539                                                        noQualifiers,
540                                                        new StructInstType(
541                                                                noQualifiers,
542                                                                decl_mask
543                                                        )
544                                                )
545                                        ),
546                                        timeout
547                                }
548                        )
549                ));
550
551                return new VariableExpr( index );
552        }
553
554        void GenerateWaitForPass::choose(
555                WaitForStmt * waitfor,
556                Expression  * result,
557                CompoundStmt * stmt
558        ) {
559                SwitchStmt * swtch = new SwitchStmt(
560                        noLabels,
561                        result,
562                        std::list<Statement *>()
563                );
564
565                unsigned long i = 0;
566                for( auto & clause : waitfor->clauses ) {
567                        swtch->statements.push_back(
568                                new CaseStmt(
569                                        noLabels,
570                                        new ConstantExpr( Constant::from_ulong( i++ ) ),
571                                        {
572                                                clause.statement,
573                                                new BranchStmt(
574                                                        noLabels,
575                                                        "",
576                                                        BranchStmt::Break
577                                                )
578                                        }
579                                )
580                        );
581                }
582
583                if(waitfor->timeout.statement) {
584                        swtch->statements.push_back(
585                                new CaseStmt(
586                                        noLabels,
587                                        new ConstantExpr( Constant::from_int( -2 ) ),
588                                        {
589                                                waitfor->timeout.statement,
590                                                new BranchStmt(
591                                                        noLabels,
592                                                        "",
593                                                        BranchStmt::Break
594                                                )
595                                        }
596                                )
597                        );
598                }
599
600                if(waitfor->orelse.statement) {
601                        swtch->statements.push_back(
602                                new CaseStmt(
603                                        noLabels,
604                                        new ConstantExpr( Constant::from_int( -1 ) ),
605                                        {
606                                                waitfor->orelse.statement,
607                                                new BranchStmt(
608                                                        noLabels,
609                                                        "",
610                                                        BranchStmt::Break
611                                                )
612                                        }
613                                )
614                        );
615                }
616
617                stmt->push_back( swtch );
618        }
619};
620
621// Local Variables: //
622// mode: c //
623// tab-width: 4 //
624// End: //
Note: See TracBrowser for help on using the repository browser.