source: src/Concurrency/WaitforNew.cpp @ ded6c2a6

ast-experimental
Last change on this file since ded6c2a6 was c86b08d, checked in by caparsons <caparson@…>, 19 months ago

added support for the waituntil statement in the compiler

  • Property mode set to 100644
File size: 16.1 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// WaitforNew.cpp -- Expand waitfor clauses into code.
8//
9// Author           : Andrew Beach
10// Created On       : Fri May 27 10:31:00 2022
11// Last Modified By : Andrew Beach
12// Last Modified On : Tue Jun 13 13:30:00 2022
13// Update Count     : 0
14//
15
16#include "Waitfor.h"
17
18#include <string>
19
20#include "AST/Pass.hpp"
21#include "Common/UniqueName.h"
22#include "InitTweak/InitTweak.h"
23#include "ResolvExpr/Resolver.h"
24
25#include "AST/Print.hpp"
26
27using namespace std::string_literals;
28using ResolvExpr::ResolveContext;
29
30/* So this is what this file dones:
31
32void f(int i, float f, A & mutex b, struct foo *  );
33void f(int );
34
35...{
36        when ( a < 1 ) waitfor( f : a ) { fee(); }
37        or timeout( getWaitTime() ) { fy(); }
38        or waitfor( g : a ) { foe(); }
39        or waitfor( ^?{} : a ) { break; }
40        or waitfor( ^?{} ) { break; }
41        or when ( a < 1 ) else { fum(); }
42}...
43
44                 ||
45                 ||
46                \||/
47                 \/
48
49...{
50        {
51                __acceptable_t __acceptables_#[4 <num-clauses>];
52                bool __do_run_# = false;
53
54                monitor$ * __monitors_#[1 <num-monitors>] = { a };
55                if ( a < 1) {
56                        void (*__function_#)() = <casts> f;
57                        __acceptables_#[0].is_dtor = false;
58                        __acceptables_#[0].func = __function_#;
59                        __acceptables_#[0].data = __monitors_#;
60                        __acceptables_#[0].size = 1;
61                        __do_run_# = true;
62                }
63
64                // Remaining waitfor clauses go here.
65
66                long long unsigned int __timeout_# = -1;
67                if ( true ) {
68                        __timeout_# = getWaitTime();
69                        __do_run_# = true;
70                }
71
72                if ( a < 1 ) {
73                        __timeout_# = 0
74                        __do_run_# = true;
75                }
76
77                short int __index_# = -1;
78                __waitfor_mask_t __mask_# = {&__index_#, {__acceptables_#, ?}};
79                __waitfor_internal((__waitfor_mask_t&)__mask_#, __timeout_#);
80
81                switch (__index_#) {
82                case 0:
83                        { { fee(); } break; }
84                case 1:
85                        { { foe(); } break; }
86                case 2:
87                        { <modified-break> break; }
88                case 3:
89                        { <modified-break> break; }
90                case -2:
91                        { { fy(); } break; }
92                case -1:
93                        { { foe(); } break; }
94                }
95        }
96}...
97*/
98
99namespace Concurrency {
100
101namespace {
102
103class GenerateWaitForCore final :
104                public ast::WithSymbolTable, public ast::WithConstTranslationUnit {
105        const ast::FunctionDecl * decl_waitfor    = nullptr;
106        const ast::StructDecl   * decl_mask       = nullptr;
107        const ast::StructDecl   * decl_acceptable = nullptr;
108        const ast::StructDecl   * decl_monitor    = nullptr;
109
110        UniqueName namer_acc = "__acceptables_"s;
111        UniqueName namer_idx = "__index_"s;
112        UniqueName namer_flg = "__do_run_"s;
113        UniqueName namer_msk = "__mask_"s;
114        UniqueName namer_mon = "__monitors_"s;
115        UniqueName namer_tim = "__timeout_"s;
116        UniqueName namer_fun = "__function_"s;
117
118        ast::ObjectDecl * declareAcceptables( ast::CompoundStmt * out,
119                const CodeLocation & location, unsigned long numClauses );
120        ast::ObjectDecl * declareFlag(
121                ast::CompoundStmt * out, const CodeLocation & location );
122        ast::ExprStmt * makeSetter(
123                const CodeLocation & location, ast::ObjectDecl * flag );
124        ast::ObjectDecl * declMonitors(
125                ast::CompoundStmt * out, const ast::WaitForClause * clause );
126        void init_clause( ast::CompoundStmt * out, ast::ObjectDecl * acceptables,
127                int index, const ast::WaitForClause * clause, ast::Stmt * setter );
128        ast::Expr * init_timeout(
129                ast::CompoundStmt * out, const CodeLocation & topLocation,
130                const ast::Expr * timeout_time, const ast::Expr * timeout_cond,
131                const ast::Stmt * else_stmt, const ast::Expr * else_cond,
132                const ast::Stmt * setter );
133        ast::Expr * call(
134                ast::CompoundStmt * out, const CodeLocation & location,
135                size_t numClauses, ast::ObjectDecl * acceptables,
136                ast::Expr * timeout );
137public:
138        void previsit( const ast::FunctionDecl * decl );
139        void previsit( const ast::StructDecl * decl );
140        ast::Stmt * postvisit( const ast::WaitForStmt * stmt );
141};
142
143ast::Expr * makeOpIndex( const CodeLocation & location,
144                const ast::DeclWithType * array, unsigned long index ) {
145        return new ast::UntypedExpr( location,
146                new ast::NameExpr( location, "?[?]" ),
147                {
148                        new ast::VariableExpr( location, array ),
149                        ast::ConstantExpr::from_ulong( location, index ),
150                }
151        );
152}
153
154ast::Expr * makeOpAssign( const CodeLocation & location,
155                const ast::Expr * lhs, const ast::Expr * rhs ) {
156        return new ast::UntypedExpr( location,
157                new ast::NameExpr( location, "?=?" ),
158                { lhs, rhs }
159        );
160}
161
162ast::Expr * makeOpMember( const CodeLocation & location,
163                const std::string & mem, const ast::Expr * sue ) {
164        return new ast::UntypedMemberExpr( location,
165                new ast::NameExpr( location, mem ),
166                sue
167        );
168}
169
170ast::Stmt * makeAccStmt(
171                const CodeLocation & location, ast::DeclWithType * object,
172                unsigned long index, const std::string & member,
173                const ast::Expr * value, const ResolveContext & context
174) {
175        ast::Expr * expr = makeOpAssign( location,
176                makeOpMember( location,
177                        member,
178                        makeOpIndex( location,
179                                object,
180                                index
181                        )
182                ),
183                value
184        );
185
186        auto result = ResolvExpr::findVoidExpression( expr, context );
187        return new ast::ExprStmt( location, result.get() );
188}
189
190const ast::Stmt * maybeCond( const CodeLocation & location,
191                const ast::Expr * cond, std::list<ast::ptr<ast::Stmt>> && stmts ) {
192        ast::Stmt * block = new ast::CompoundStmt( location, std::move( stmts ) );
193        return (cond) ? new ast::IfStmt( location, cond, block ) : block;
194}
195
196const ast::VariableExpr * extractVariable( const ast::Expr * func ) {
197        if ( auto var = dynamic_cast<const ast::VariableExpr *>( func ) ) {
198                return var;
199        }
200        auto cast = strict_dynamic_cast<const ast::CastExpr *>( func );
201        return cast->arg.strict_as<ast::VariableExpr>();
202}
203
204const ast::Expr * detectIsDtor(
205                const CodeLocation & location, const ast::Expr * func ) {
206        const ast::VariableExpr * typed_func = extractVariable( func );
207        bool is_dtor = InitTweak::isDestructor(
208                typed_func->var.strict_as<ast::FunctionDecl>() );
209        return ast::ConstantExpr::from_bool( location, is_dtor );
210}
211
212ast::ObjectDecl * GenerateWaitForCore::declareAcceptables(
213                ast::CompoundStmt * out,
214                const CodeLocation & location, unsigned long numClauses ) {
215        ast::ObjectDecl * acceptables = new ast::ObjectDecl( location,
216                namer_acc.newName(),
217                new ast::ArrayType(
218                        new ast::StructInstType( decl_acceptable ),
219                        ast::ConstantExpr::from_ulong( location, numClauses ),
220                        ast::FixedLen,
221                        ast::DynamicDim
222                )
223        );
224        out->push_back( new ast::DeclStmt( location, acceptables ) );
225
226        ast::Expr * set = new ast::UntypedExpr( location,
227                new ast::NameExpr( location, "__builtin_memset" ),
228                {
229                        new ast::VariableExpr( location, acceptables ),
230                        ast::ConstantExpr::from_int( location, 0 ),
231                        new ast::SizeofExpr( location,
232                                new ast::VariableExpr( location, acceptables ) ),
233                }
234        );
235        ResolveContext context{ symtab, transUnit().global };
236        auto result = ResolvExpr::findVoidExpression( set, context );
237        out->push_back( new ast::ExprStmt( location, result.get() ) );
238
239        return acceptables;
240}
241
242ast::ObjectDecl * GenerateWaitForCore::declareFlag(
243                ast::CompoundStmt * out, const CodeLocation & location ) {
244        ast::ObjectDecl * flag = new ast::ObjectDecl( location,
245                namer_flg.newName(),
246                new ast::BasicType( ast::BasicType::Bool ),
247                new ast::SingleInit( location,
248                        ast::ConstantExpr::from_ulong( location, 0 )
249                )
250        );
251        out->push_back( new ast::DeclStmt( location, flag ) );
252        return flag;
253}
254
255ast::ExprStmt * GenerateWaitForCore::makeSetter(
256                const CodeLocation & location, ast::ObjectDecl * flag ) {
257        ast::Expr * expr = new ast::UntypedExpr( location,
258                new ast::NameExpr( location, "?=?" ),
259                {
260                        new ast::VariableExpr( location, flag ),
261                        ast::ConstantExpr::from_ulong( location, 1 ),
262                }
263        );
264        ResolveContext context{ symtab, transUnit().global };
265        auto result = ResolvExpr::findVoidExpression( expr, context );
266        return new ast::ExprStmt( location, result.get() );
267}
268
269ast::ObjectDecl * GenerateWaitForCore::declMonitors(
270                ast::CompoundStmt * out,
271                const ast::WaitForClause * clause ) {
272        const CodeLocation & location = clause->location;
273        ast::ObjectDecl * monitor = new ast::ObjectDecl( location,
274                namer_mon.newName(),
275                new ast::ArrayType(
276                        new ast::PointerType(
277                                new ast::StructInstType( decl_monitor )
278                        ),
279                        ast::ConstantExpr::from_ulong( location, clause->target_args.size() ),
280                        ast::FixedLen,
281                        ast::DynamicDim
282                ),
283                new ast::ListInit( location,
284                        map_range<std::vector<ast::ptr<ast::Init>>>(
285                                clause->target_args,
286                                []( const ast::Expr * expr ){
287                                        return new ast::SingleInit( expr->location, expr ); }
288                        )
289                )
290        );
291        out->push_back( new ast::DeclStmt( location, monitor ) );
292        return monitor;
293}
294
295void GenerateWaitForCore::init_clause(
296                ast::CompoundStmt * out,
297                ast::ObjectDecl * acceptables,
298                int index,
299                const ast::WaitForClause * clause,
300                ast::Stmt * setter ) {
301        const CodeLocation & location = clause->location;
302        const ast::ObjectDecl * monitors = declMonitors( out, clause );
303        ast::Type * fptr_t = new ast::PointerType(
304                        new ast::FunctionType( ast::VariableArgs ) );
305
306        const ast::VariableExpr * variableExpr =
307                clause->target.as<ast::VariableExpr>();
308        ast::Expr * castExpr = new ast::CastExpr(
309                location,
310                new ast::CastExpr(
311                        location,
312                        clause->target,
313                        ast::deepCopy( variableExpr->result.get() ),
314                        ast::GeneratedCast ),
315                fptr_t,
316                ast::GeneratedCast );
317
318        ast::ObjectDecl * funcDecl = new ast::ObjectDecl( location,
319                namer_fun.newName(),
320                ast::deepCopy( fptr_t ),
321                new ast::SingleInit( location, castExpr )
322                );
323        ast::Expr * funcExpr = new ast::VariableExpr( location, funcDecl );
324        out->push_back( new ast::DeclStmt( location, funcDecl ) );
325
326        ResolveContext context{ symtab, transUnit().global };
327        out->push_back( maybeCond( location, clause->when_cond.get(), {
328                makeAccStmt( location, acceptables, index, "is_dtor",
329                        detectIsDtor( location, clause->target ), context ),
330                makeAccStmt( location, acceptables, index, "func",
331                        funcExpr, context ),
332                makeAccStmt( location, acceptables, index, "data",
333                        new ast::VariableExpr( location, monitors ), context ),
334                makeAccStmt( location, acceptables, index, "size",
335                        ast::ConstantExpr::from_ulong( location,
336                                clause->target_args.size() ), context ),
337                ast::deepCopy( setter ),
338        } ) );
339}
340
341ast::Expr * GenerateWaitForCore::init_timeout(
342                ast::CompoundStmt * out,
343                const CodeLocation & topLocation,
344                const ast::Expr * timeout_time,
345                const ast::Expr * timeout_cond,
346                const ast::Stmt * else_stmt,
347                const ast::Expr * else_cond,
348                const ast::Stmt * setter ) {
349        ast::ObjectDecl * timeout = new ast::ObjectDecl( topLocation,
350                namer_tim.newName(),
351                new ast::BasicType( ast::BasicType::LongLongUnsignedInt ),
352                new ast::SingleInit( topLocation,
353                        ast::ConstantExpr::from_int( topLocation, -1 )
354                )
355        );
356        out->push_back( new ast::DeclStmt( topLocation, timeout ) );
357
358        if ( timeout_time ) {
359                const CodeLocation & location = timeout_time->location;
360                out->push_back( maybeCond( location, timeout_cond, {
361                        new ast::ExprStmt( location,
362                                makeOpAssign(
363                                        location,
364                                        new ast::VariableExpr( location, timeout ),
365                                        timeout_time
366                                )
367                        ),
368                        ast::deepCopy( setter ),
369                } ) );
370        }
371
372        // We only care about the else_stmt's presence and location.
373        if ( else_stmt ) {
374                const CodeLocation & location = else_stmt->location;
375                out->push_back( maybeCond( location, else_cond, {
376                        new ast::ExprStmt( location,
377                                makeOpAssign(
378                                        location,
379                                        new ast::VariableExpr( location, timeout ),
380                                        ast::ConstantExpr::from_ulong( location, 0 )
381                                )
382                        ),
383                        ast::deepCopy( setter ),
384                } ) );
385        }
386
387        return new ast::VariableExpr( topLocation, timeout );
388}
389
390ast::Expr * GenerateWaitForCore::call(
391        ast::CompoundStmt * out,
392        const CodeLocation & location,
393        size_t numClauses,
394        ast::ObjectDecl * acceptables,
395        ast::Expr * timeout
396) {
397        ast::ObjectDecl * index = new ast::ObjectDecl( location,
398                namer_idx.newName(),
399                new ast::BasicType( ast::BasicType::ShortSignedInt ),
400                new ast::SingleInit( location,
401                        ast::ConstantExpr::from_int( location, -1 )
402                )
403        );
404        out->push_back( new ast::DeclStmt( location, index ) );
405
406        ast::ObjectDecl * mask = new ast::ObjectDecl( location,
407                namer_msk.newName(),
408                new ast::StructInstType( decl_mask ),
409                new ast::ListInit( location, {
410                        new ast::SingleInit( location,
411                                new ast::AddressExpr( location,
412                                        new ast::VariableExpr( location, index )
413                                )
414                        ),
415                        new ast::ListInit( location, {
416                                new ast::SingleInit( location,
417                                        new ast::VariableExpr( location, acceptables )
418                                ),
419                                new ast::SingleInit( location,
420                                        ast::ConstantExpr::from_ulong( location, numClauses )
421                                ),
422                        }),
423                })
424        );
425        out->push_back( new ast::DeclStmt( location, mask ) );
426
427        ast::ApplicationExpr * waitforMask = new ast::ApplicationExpr( location,
428                ast::VariableExpr::functionPointer( location, decl_waitfor ),
429                {
430                        new ast::CastExpr(
431                                new ast::VariableExpr( location, mask ),
432                                new ast::ReferenceType(
433                                        new ast::StructInstType( decl_mask )
434                                )
435                        ),
436                        timeout
437                }
438        );
439        out->push_back( new ast::ExprStmt( location, waitforMask ) );
440
441        return new ast::VariableExpr( location, index );
442}
443
444ast::Stmt * choose( const ast::WaitForStmt * waitfor, ast::Expr * result ) {
445        const CodeLocation & location = waitfor->location;
446
447        ast::SwitchStmt * theSwitch = new ast::SwitchStmt( location,
448                result,
449                std::vector<ast::ptr<ast::CaseClause>>()
450        );
451
452        // For some reason, enumerate doesn't work here because of references.
453        for ( size_t i = 0 ; i < waitfor->clauses.size() ; ++i ) {
454                theSwitch->cases.push_back(
455                        new ast::CaseClause( location,
456                                ast::ConstantExpr::from_ulong( location, i ),
457                                {
458                                        new ast::CompoundStmt( location, {
459                                                waitfor->clauses[i]->stmt,
460                                                new ast::BranchStmt( location,
461                                                        ast::BranchStmt::Break,
462                                                        ast::Label( location )
463                                                )
464                                        })
465                                }
466                        )
467                );
468        }
469
470        if ( waitfor->timeout_stmt ) {
471                theSwitch->cases.push_back(
472                        new ast::CaseClause( location,
473                                ast::ConstantExpr::from_int( location, -2 ),
474                                {
475                                        new ast::CompoundStmt( location, {
476                                                waitfor->timeout_stmt,
477                                                new ast::BranchStmt( location,
478                                                        ast::BranchStmt::Break,
479                                                        ast::Label( location )
480                                                )
481                                        })
482                                }
483                        )
484                );
485        }
486
487        if ( waitfor->else_stmt ) {
488                theSwitch->cases.push_back(
489                        new ast::CaseClause( location,
490                                ast::ConstantExpr::from_int( location, -1 ),
491                                {
492                                        new ast::CompoundStmt( location, {
493                                                waitfor->else_stmt,
494                                                new ast::BranchStmt( location,
495                                                        ast::BranchStmt::Break,
496                                                        ast::Label( location )
497                                                )
498                                        })
499                                }
500                        )
501                );
502        }
503
504        return theSwitch;
505}
506
507void GenerateWaitForCore::previsit( const ast::FunctionDecl * decl ) {
508        if ( "__waitfor_internal" == decl->name ) {
509                decl_waitfor = decl;
510        }
511}
512
513void GenerateWaitForCore::previsit( const ast::StructDecl * decl ) {
514        if ( !decl->body ) {
515                return;
516        } else if ( "__acceptable_t" == decl->name ) {
517                assert( !decl_acceptable );
518                decl_acceptable = decl;
519        } else if ( "__waitfor_mask_t" == decl->name ) {
520                assert( !decl_mask );
521                decl_mask = decl;
522        } else if ( "monitor$" == decl->name ) {
523                assert( !decl_monitor );
524                decl_monitor = decl;
525        }
526}
527
528ast::Stmt * GenerateWaitForCore::postvisit( const ast::WaitForStmt * stmt ) {
529        if ( !decl_monitor || !decl_acceptable || !decl_mask ) {
530                SemanticError( stmt, "waitfor keyword requires monitors to be in scope, add #include <monitor.hfa>" );
531        }
532
533        const CodeLocation & location = stmt->location;
534        ast::CompoundStmt * comp = new ast::CompoundStmt( location );
535
536        ast::ObjectDecl * acceptables = declareAcceptables( comp, location, stmt->clauses.size() );
537        ast::ObjectDecl * flag        = declareFlag( comp, location );
538        ast::Stmt       * setter      = makeSetter( location, flag );
539
540        // For some reason, enumerate doesn't work here because of references.
541        for ( size_t i = 0 ; i < stmt->clauses.size() ; ++i ) {
542                init_clause( comp, acceptables, i, stmt->clauses[i], setter );
543        }
544
545        ast::Expr * timeout = init_timeout(
546                comp,
547                location,
548                stmt->timeout_time,
549                stmt->timeout_cond,
550                stmt->else_stmt,
551                stmt->else_cond,
552                setter
553        );
554
555        ast::CompoundStmt * compound = new ast::CompoundStmt( location );
556        comp->push_back( new ast::IfStmt( location,
557                new ast::VariableExpr( location, flag ),
558                compound,
559                nullptr
560        ));
561
562        ast::Expr * result = call(
563                compound, location, stmt->clauses.size(), acceptables, timeout );
564        compound->push_back( choose( stmt, result ) );
565        return comp;
566}
567
568} // namespace
569
570void generateWaitFor( ast::TranslationUnit & translationUnit ) {
571        ast::Pass<GenerateWaitForCore>::run( translationUnit );
572}
573
574} // namespace Concurrency
575
576// Local Variables: //
577// tab-width: 4 //
578// mode: c++ //
579// compile-command: "make install" //
580// End: //
Note: See TracBrowser for help on using the repository browser.