source: src/Concurrency/Keywords.cpp @ cc0aa8c

Last change on this file since cc0aa8c was c92bdcc, checked in by Andrew Beach <ajbeach@…>, 6 months ago

Updated the rest of the names in src/ (except for the generated files).

  • Property mode set to 100644
File size: 48.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Keywords.cpp -- Implement concurrency constructs from their keywords.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Nov 16  9:53:00 2021
11// Last Modified By : Peter A. Buhr
12// Last Modified On : Thu Dec 14 18:02:25 2023
13// Update Count     : 6
14//
15
16#include "Concurrency/Keywords.hpp"
17
18#include <iostream>
19
20#include "AST/Copy.hpp"
21#include "AST/Decl.hpp"
22#include "AST/Expr.hpp"
23#include "AST/Inspect.hpp"
24#include "AST/Pass.hpp"
25#include "AST/Stmt.hpp"
26#include "AST/DeclReplacer.hpp"
27#include "AST/TranslationUnit.hpp"
28#include "CodeGen/OperatorTable.hpp"
29#include "Common/Examine.hpp"
30#include "Common/Utility.hpp"
31#include "Common/UniqueName.hpp"
32#include "ControlStruct/LabelGenerator.hpp"
33#include "InitTweak/InitTweak.hpp"
34#include "Virtual/Tables.hpp"
35
36namespace Concurrency {
37
38namespace {
39
40// --------------------------------------------------------------------------
41// Loose Helper Functions:
42
43/// Detect threads constructed with the keyword thread.
44bool isThread( const ast::DeclWithType * decl ) {
45        auto baseType = decl->get_type()->stripDeclarator();
46        auto instType = dynamic_cast<const ast::StructInstType *>( baseType );
47        if ( nullptr == instType ) { return false; }
48        return instType->base->is_thread();
49}
50
51/// Get the virtual type id if given a type name.
52std::string typeIdType( std::string const & exception_name ) {
53        return exception_name.empty() ? std::string()
54                : Virtual::typeIdType( exception_name );
55}
56
57/// Get the vtable type name if given a type name.
58std::string vtableTypeName( std::string const & exception_name ) {
59        return exception_name.empty() ? std::string()
60                : Virtual::vtableTypeName( exception_name );
61}
62
63static ast::Type * mutate_under_references( ast::ptr<ast::Type>& type ) {
64        ast::Type * mutType = type.get_and_mutate();
65        for ( ast::ReferenceType * mutRef
66                ; (mutRef = dynamic_cast<ast::ReferenceType *>( mutType ))
67                ; mutType = mutRef->base.get_and_mutate() );
68        return mutType;
69}
70
71// Describe that it adds the generic parameters and the uses of the generic
72// parameters on the function and first "this" argument.
73ast::FunctionDecl * fixupGenerics(
74                const ast::FunctionDecl * func, const ast::StructDecl * decl ) {
75        const CodeLocation & location = decl->location;
76        // We have to update both the declaration
77        auto mutFunc = ast::mutate( func );
78        auto mutType = mutFunc->type.get_and_mutate();
79
80        if ( decl->params.empty() ) {
81                return mutFunc;
82        }
83
84        assert( 0 != mutFunc->params.size() );
85        assert( 0 != mutType->params.size() );
86
87        // Add the "forall" clause information.
88        for ( const ast::ptr<ast::TypeDecl> & typeParam : decl->params ) {
89                auto typeDecl = ast::deepCopy( typeParam );
90                mutFunc->type_params.push_back( typeDecl );
91                mutType->forall.push_back( new ast::TypeInstType( typeDecl ) );
92                for ( auto & assertion : typeDecl->assertions ) {
93                        mutFunc->assertions.push_back( assertion );
94                        mutType->assertions.emplace_back(
95                                new ast::VariableExpr( location, assertion ) );
96                }
97                typeDecl->assertions.clear();
98        }
99
100        // Even chain_mutate is not powerful enough for this:
101        ast::ptr<ast::Type>& paramType = strict_dynamic_cast<ast::ObjectDecl *>(
102                mutFunc->params[0].get_and_mutate() )->type;
103        auto paramTypeInst = strict_dynamic_cast<ast::StructInstType *>(
104                mutate_under_references( paramType ) );
105        auto typeParamInst = strict_dynamic_cast<ast::StructInstType *>(
106                mutate_under_references( mutType->params[0] ) );
107
108        for ( const ast::ptr<ast::TypeDecl> & typeDecl : mutFunc->type_params ) {
109                paramTypeInst->params.push_back(
110                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
111                typeParamInst->params.push_back(
112                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
113        }
114
115        return mutFunc;
116}
117
118// --------------------------------------------------------------------------
119struct ConcurrentSueKeyword : public ast::WithDeclsToAdd<> {
120        ConcurrentSueKeyword(
121                std::string&& type_name, std::string&& field_name,
122                std::string&& getter_name, std::string&& context_error,
123                std::string&& exception_name,
124                bool needs_main, ast::AggregateDecl::Aggregate cast_target
125        ) :
126                type_name( type_name ), field_name( field_name ),
127                getter_name( getter_name ), context_error( context_error ),
128                exception_name( exception_name ),
129                typeid_name( typeIdType( exception_name ) ),
130                vtable_name( vtableTypeName( exception_name ) ),
131                needs_main( needs_main ), cast_target( cast_target )
132        {}
133
134        virtual ~ConcurrentSueKeyword() {}
135
136        const ast::Decl * postvisit( const ast::StructDecl * decl );
137        const ast::DeclWithType * postvisit( const ast::FunctionDecl * decl );
138        const ast::Expr * postvisit( const ast::KeywordCastExpr * expr );
139
140        struct StructAndField {
141                const ast::StructDecl * decl;
142                const ast::ObjectDecl * field;
143        };
144
145        const ast::StructDecl * handleStruct( const ast::StructDecl * );
146        void handleMain( const ast::FunctionDecl *, const ast::StructInstType * );
147        void addTypeId( const ast::StructDecl * );
148        void addVtableForward( const ast::StructDecl * );
149        const ast::FunctionDecl * forwardDeclare( const ast::StructDecl * );
150        StructAndField addField( const ast::StructDecl * );
151        void addGetRoutines( const ast::ObjectDecl *, const ast::FunctionDecl * );
152        void addLockUnlockRoutines( const ast::StructDecl * );
153
154private:
155        const std::string type_name;
156        const std::string field_name;
157        const std::string getter_name;
158        const std::string context_error;
159        const std::string exception_name;
160        const std::string typeid_name;
161        const std::string vtable_name;
162        const bool needs_main;
163        const ast::AggregateDecl::Aggregate cast_target;
164
165        const ast::StructDecl   * type_decl = nullptr;
166        const ast::FunctionDecl * dtor_decl = nullptr;
167        const ast::StructDecl * except_decl = nullptr;
168        const ast::StructDecl * typeid_decl = nullptr;
169        const ast::StructDecl * vtable_decl = nullptr;
170
171};
172
173// Handles thread type declarations:
174//
175// thread Mythread {                         struct MyThread {
176//  int data;                                  int data;
177//  a_struct_t more_data;                      a_struct_t more_data;
178//                                =>             thread$ __thrd_d;
179// };                                        };
180//                                           static inline thread$ * get_thread( MyThread * this ) { return &this->__thrd_d; }
181//
182struct ThreadKeyword final : public ConcurrentSueKeyword {
183        ThreadKeyword() : ConcurrentSueKeyword(
184                "thread$",
185                "__thrd",
186                "get_thread",
187                "thread keyword requires threads to be in scope, add #include <thread.hfa>\n",
188                "ThreadCancelled",
189                true,
190                ast::AggregateDecl::Thread )
191        {}
192
193        virtual ~ThreadKeyword() {}
194};
195
196// Handles coroutine type declarations:
197//
198// coroutine MyCoroutine {                   struct MyCoroutine {
199//  int data;                                  int data;
200//  a_struct_t more_data;                      a_struct_t more_data;
201//                                =>             coroutine$ __cor_d;
202// };                                        };
203//                                           static inline coroutine$ * get_coroutine( MyCoroutine * this ) { return &this->__cor_d; }
204//
205struct CoroutineKeyword final : public ConcurrentSueKeyword {
206        CoroutineKeyword() : ConcurrentSueKeyword(
207                "coroutine$",
208                "__cor",
209                "get_coroutine",
210                "coroutine keyword requires coroutines to be in scope, add #include <coroutine.hfa>\n",
211                "CoroutineCancelled",
212                true,
213                ast::AggregateDecl::Coroutine )
214        {}
215
216        virtual ~CoroutineKeyword() {}
217};
218
219// Handles monitor type declarations:
220//
221// monitor MyMonitor {                       struct MyMonitor {
222//  int data;                                  int data;
223//  a_struct_t more_data;                      a_struct_t more_data;
224//                                =>             monitor$ __mon_d;
225// };                                        };
226//                                           static inline monitor$ * get_coroutine( MyMonitor * this ) {
227//                                               return &this->__cor_d;
228//                                           }
229//                                           void lock(MyMonitor & this) {
230//                                               lock(get_monitor(this));
231//                                           }
232//                                           void unlock(MyMonitor & this) {
233//                                               unlock(get_monitor(this));
234//                                           }
235//
236struct MonitorKeyword final : public ConcurrentSueKeyword {
237        MonitorKeyword() : ConcurrentSueKeyword(
238                "monitor$",
239                "__mon",
240                "get_monitor",
241                "monitor keyword requires monitors to be in scope, add #include <monitor.hfa>\n",
242                "",
243                false,
244                ast::AggregateDecl::Monitor )
245        {}
246
247        virtual ~MonitorKeyword() {}
248};
249
250// Handles generator type declarations:
251//
252// generator MyGenerator {                   struct MyGenerator {
253//  int data;                                  int data;
254//  a_struct_t more_data;                      a_struct_t more_data;
255//                                =>             int __generator_state;
256// };                                        };
257//
258struct GeneratorKeyword final : public ConcurrentSueKeyword {
259        GeneratorKeyword() : ConcurrentSueKeyword(
260                "generator$",
261                "__generator_state",
262                "get_generator",
263                "Unable to find builtin type generator$\n",
264                "",
265                true,
266                ast::AggregateDecl::Generator )
267        {}
268
269        virtual ~GeneratorKeyword() {}
270};
271
272const ast::Decl * ConcurrentSueKeyword::postvisit(
273                const ast::StructDecl * decl ) {
274        if ( !decl->body ) {
275                return decl;
276        } else if ( cast_target == decl->kind ) {
277                return handleStruct( decl );
278        } else if ( type_name == decl->name ) {
279                assert( !type_decl );
280                type_decl = decl;
281        } else if ( exception_name == decl->name ) {
282                assert( !except_decl );
283                except_decl = decl;
284        } else if ( typeid_name == decl->name ) {
285                assert( !typeid_decl );
286                typeid_decl = decl;
287        } else if ( vtable_name == decl->name ) {
288                assert( !vtable_decl );
289                vtable_decl = decl;
290        }
291        return decl;
292}
293
294// Try to get the full definition, but raise an error on conflicts.
295const ast::FunctionDecl * getDefinition(
296                const ast::FunctionDecl * old_decl,
297                const ast::FunctionDecl * new_decl ) {
298        if ( !new_decl->stmts ) {
299                return old_decl;
300        } else if ( !old_decl->stmts ) {
301                return new_decl;
302        } else {
303                assert( !old_decl->stmts || !new_decl->stmts );
304                return nullptr;
305        }
306}
307
308const ast::DeclWithType * ConcurrentSueKeyword::postvisit(
309                const ast::FunctionDecl * decl ) {
310        if ( type_decl && isDestructorFor( decl, type_decl ) ) {
311                // Check for forward declarations, try to get the full definition.
312                dtor_decl = (dtor_decl) ? getDefinition( dtor_decl, decl ) : decl;
313        } else if ( !vtable_name.empty() && decl->has_body() ) {
314                if (const ast::DeclWithType * param = isMainFor( decl, cast_target )) {
315                        if ( !vtable_decl ) {
316                                SemanticError( decl, context_error );
317                        }
318                        // Should be safe because of isMainFor.
319                        const ast::StructInstType * struct_type =
320                                static_cast<const ast::StructInstType *>(
321                                        static_cast<const ast::ReferenceType *>(
322                                                param->get_type() )->base.get() );
323
324                        handleMain( decl, struct_type );
325                }
326        }
327        return decl;
328}
329
330const ast::Expr * ConcurrentSueKeyword::postvisit(
331                const ast::KeywordCastExpr * expr ) {
332        if ( cast_target == expr->target ) {
333                // Convert `(thread &)ex` to `(thread$ &)*get_thread(ex)`, etc.
334                if ( !type_decl || !dtor_decl ) {
335                        SemanticError( expr, context_error );
336                }
337                assert( nullptr == expr->result );
338                auto cast = ast::mutate( expr );
339                cast->result = new ast::ReferenceType( new ast::StructInstType( type_decl ) );
340                cast->concrete_target.field  = field_name;
341                cast->concrete_target.getter = getter_name;
342                return cast;
343        }
344        return expr;
345}
346
347const ast::StructDecl * ConcurrentSueKeyword::handleStruct(
348                const ast::StructDecl * decl ) {
349        assert( decl->body );
350
351        if ( !type_decl || !dtor_decl ) {
352                SemanticError( decl, context_error );
353        }
354
355        if ( !exception_name.empty() ) {
356                if( !typeid_decl || !vtable_decl ) {
357                        SemanticError( decl, context_error );
358                }
359                addTypeId( decl );
360                addVtableForward( decl );
361        }
362
363        const ast::FunctionDecl * func = forwardDeclare( decl );
364        StructAndField addFieldRet = addField( decl );
365        decl = addFieldRet.decl;
366        const ast::ObjectDecl * field = addFieldRet.field;
367
368        addGetRoutines( field, func );
369        // Add routines to monitors for use by mutex stmt.
370        if ( ast::AggregateDecl::Monitor == cast_target ) {
371                addLockUnlockRoutines( decl );
372        }
373
374        return decl;
375}
376
377void ConcurrentSueKeyword::handleMain(
378                const ast::FunctionDecl * decl, const ast::StructInstType * type ) {
379        assert( vtable_decl );
380        assert( except_decl );
381
382        const CodeLocation & location = decl->location;
383
384        std::vector<ast::ptr<ast::Expr>> poly_args = {
385                new ast::TypeExpr( location, type ),
386        };
387        ast::ObjectDecl * vtable_object = Virtual::makeVtableInstance(
388                location,
389                "_default_vtable_object_declaration",
390                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
391                type,
392                nullptr
393        );
394        declsToAddAfter.push_back( vtable_object );
395        declsToAddAfter.push_back(
396                new ast::ObjectDecl(
397                        location,
398                        Virtual::concurrentDefaultVTableName(),
399                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
400                        new ast::SingleInit( location,
401                                new ast::VariableExpr( location, vtable_object ) )
402                )
403        );
404        declsToAddAfter.push_back( Virtual::makeGetExceptionFunction(
405                location,
406                vtable_object,
407                new ast::StructInstType( except_decl, copy( poly_args ) )
408        ) );
409}
410
411void ConcurrentSueKeyword::addTypeId( const ast::StructDecl * decl ) {
412        assert( typeid_decl );
413        const CodeLocation & location = decl->location;
414
415        ast::StructInstType * typeid_type =
416                new ast::StructInstType( typeid_decl, ast::CV::Const );
417        typeid_type->params.push_back(
418                new ast::TypeExpr( location, new ast::StructInstType( decl ) ) );
419        declsToAddBefore.push_back(
420                Virtual::makeTypeIdInstance( location, typeid_type ) );
421        // If the typeid_type is going to be kept, the other reference will have
422        // been made by now, but we also get to avoid extra mutates.
423        ast::ptr<ast::StructInstType> typeid_cleanup = typeid_type;
424}
425
426void ConcurrentSueKeyword::addVtableForward( const ast::StructDecl * decl ) {
427        assert( vtable_decl );
428        const CodeLocation& location = decl->location;
429
430        std::vector<ast::ptr<ast::Expr>> poly_args = {
431                new ast::TypeExpr( location, new ast::StructInstType( decl ) ),
432        };
433        declsToAddBefore.push_back( Virtual::makeGetExceptionForward(
434                location,
435                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
436                new ast::StructInstType( except_decl, copy( poly_args ) )
437        ) );
438        ast::ObjectDecl * vtable_object = Virtual::makeVtableForward(
439                location,
440                "_default_vtable_object_declaration",
441                new ast::StructInstType( vtable_decl, std::move( poly_args ) )
442        );
443        declsToAddBefore.push_back( vtable_object );
444        declsToAddBefore.push_back(
445                new ast::ObjectDecl(
446                        location,
447                        Virtual::concurrentDefaultVTableName(),
448                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
449                        nullptr,
450                        ast::Storage::Extern,
451                        ast::Linkage::Cforall
452                )
453        );
454}
455
456const ast::FunctionDecl * ConcurrentSueKeyword::forwardDeclare(
457                const ast::StructDecl * decl ) {
458        const CodeLocation & location = decl->location;
459
460        ast::StructDecl * forward = ast::deepCopy( decl );
461        {
462                // If removing members makes ref-count go to zero, do not free.
463                ast::ptr<ast::StructDecl> forward_ptr = forward;
464                forward->body = false;
465                forward->members.clear();
466                forward_ptr.release();
467        }
468
469        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
470                location,
471                "this",
472                new ast::ReferenceType( new ast::StructInstType( decl ) )
473        );
474
475        ast::ObjectDecl * ret_decl = new ast::ObjectDecl(
476                location,
477                "ret",
478                new ast::PointerType( new ast::StructInstType( type_decl ) )
479        );
480
481        ast::FunctionDecl * get_decl = new ast::FunctionDecl(
482                location,
483                getter_name,
484                { this_decl }, // params
485                { ret_decl }, // returns
486                nullptr, // stmts
487                ast::Storage::Static,
488                ast::Linkage::Cforall,
489                { new ast::Attribute( "const" ) },
490                ast::Function::Inline
491        );
492        get_decl = fixupGenerics( get_decl, decl );
493
494        ast::FunctionDecl * main_decl = nullptr;
495        if ( needs_main ) {
496                // `this_decl` is copied here because the original was used above.
497                main_decl = new ast::FunctionDecl(
498                        location,
499                        "main",
500                        { ast::deepCopy( this_decl ) },
501                        {},
502                        nullptr,
503                        ast::Storage::Classes(),
504                        ast::Linkage::Cforall
505                );
506                main_decl = fixupGenerics( main_decl, decl );
507        }
508
509        declsToAddBefore.push_back( forward );
510        if ( needs_main ) declsToAddBefore.push_back( main_decl );
511        declsToAddBefore.push_back( get_decl );
512
513        return get_decl;
514}
515
516ConcurrentSueKeyword::StructAndField ConcurrentSueKeyword::addField(
517                const ast::StructDecl * decl ) {
518        const CodeLocation & location = decl->location;
519
520        ast::ObjectDecl * field = new ast::ObjectDecl(
521                location,
522                field_name,
523                new ast::StructInstType( type_decl )
524        );
525
526        auto mutDecl = ast::mutate( decl );
527        mutDecl->members.push_back( field );
528
529        return {mutDecl, field};
530}
531
532void ConcurrentSueKeyword::addGetRoutines(
533                const ast::ObjectDecl * field, const ast::FunctionDecl * forward ) {
534        // Clone the signature and then build the body.
535        ast::FunctionDecl * decl = ast::deepCopy( forward );
536
537        // Say it is generated at the "same" places as the forward declaration.
538        const CodeLocation & location = decl->location;
539
540        const ast::DeclWithType * param = decl->params.front();
541        ast::Stmt * stmt = new ast::ReturnStmt( location,
542                new ast::AddressExpr( location,
543                        new ast::MemberExpr( location,
544                                field,
545                                new ast::CastExpr( location,
546                                        new ast::VariableExpr( location, param ),
547                                        ast::deepCopy( param->get_type()->stripReferences() ),
548                                        ast::ExplicitCast
549                                )
550                        )
551                )
552        );
553
554        decl->stmts = new ast::CompoundStmt( location, { stmt } );
555        declsToAddAfter.push_back( decl );
556}
557
558void ConcurrentSueKeyword::addLockUnlockRoutines(
559                const ast::StructDecl * decl ) {
560        // This should only be used on monitors.
561        assert( ast::AggregateDecl::Monitor == cast_target );
562
563        const CodeLocation & location = decl->location;
564
565        // The parameter for both routines.
566        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
567                location,
568                "this",
569                new ast::ReferenceType( new ast::StructInstType( decl ) )
570        );
571
572        ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
573                location,
574                "lock",
575                {
576                        // Copy the declaration of this.
577                        ast::deepCopy( this_decl ),
578                },
579                { /* returns */ },
580                nullptr,
581                ast::Storage::Static,
582                ast::Linkage::Cforall,
583                { /* attributes */ },
584                ast::Function::Inline
585        );
586        lock_decl = fixupGenerics( lock_decl, decl );
587
588        lock_decl->stmts = new ast::CompoundStmt( location, {
589                new ast::ExprStmt( location,
590                        new ast::UntypedExpr( location,
591                                new ast::NameExpr( location, "lock" ),
592                                {
593                                        new ast::UntypedExpr( location,
594                                                new ast::NameExpr( location, "get_monitor" ),
595                                                { new ast::VariableExpr( location,
596                                                        InitTweak::getParamThis( lock_decl ) ) }
597                                        )
598                                }
599                        )
600                )
601        } );
602
603        ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
604                location,
605                "unlock",
606                {
607                        // Last use, consume the declaration of this.
608                        this_decl,
609                },
610                { /* returns */ },
611                nullptr,
612                ast::Storage::Static,
613                ast::Linkage::Cforall,
614                { /* attributes */ },
615                ast::Function::Inline
616        );
617        unlock_decl = fixupGenerics( unlock_decl, decl );
618
619        unlock_decl->stmts = new ast::CompoundStmt( location, {
620                new ast::ExprStmt( location,
621                        new ast::UntypedExpr( location,
622                                new ast::NameExpr( location, "unlock" ),
623                                {
624                                        new ast::UntypedExpr( location,
625                                                new ast::NameExpr( location, "get_monitor" ),
626                                                { new ast::VariableExpr( location,
627                                                        InitTweak::getParamThis( unlock_decl ) ) }
628                                        )
629                                }
630                        )
631                )
632        } );
633
634        declsToAddAfter.push_back( lock_decl );
635        declsToAddAfter.push_back( unlock_decl );
636}
637
638
639// --------------------------------------------------------------------------
640struct SuspendKeyword final :
641                public ast::WithStmtsToAdd<>, public ast::WithGuards {
642        SuspendKeyword() = default;
643        virtual ~SuspendKeyword() = default;
644
645        void previsit( const ast::FunctionDecl * );
646        const ast::DeclWithType * postvisit( const ast::FunctionDecl * );
647        const ast::Stmt * postvisit( const ast::SuspendStmt * );
648
649private:
650        bool is_real_suspend( const ast::FunctionDecl * );
651
652        const ast::Stmt * make_generator_suspend( const ast::SuspendStmt * );
653        const ast::Stmt * make_coroutine_suspend( const ast::SuspendStmt * );
654
655        struct LabelPair {
656                ast::Label obj;
657                int idx;
658        };
659
660        LabelPair make_label(const ast::Stmt * stmt ) {
661                labels.push_back( ControlStruct::newLabel( "generator", stmt ) );
662                return { labels.back(), int(labels.size()) };
663        }
664
665        const ast::DeclWithType * in_generator = nullptr;
666        const ast::FunctionDecl * decl_suspend = nullptr;
667        std::vector<ast::Label> labels;
668};
669
670void SuspendKeyword::previsit( const ast::FunctionDecl * decl ) {
671        GuardValue( in_generator ); in_generator = nullptr;
672
673        // If it is the real suspend, grab it if we don't have one already.
674        if ( is_real_suspend( decl ) ) {
675                decl_suspend = decl_suspend ? decl_suspend : decl;
676                return;
677        }
678
679        // Otherwise check if this is a generator main and, if so, handle it.
680        auto param = isMainFor( decl, ast::AggregateDecl::Generator );
681        if ( !param ) return;
682
683        if ( 0 != decl->returns.size() ) {
684                SemanticError( decl->location, "Generator main must return void." );
685        }
686
687        in_generator = param;
688        GuardValue( labels ); labels.clear();
689}
690
691const ast::DeclWithType * SuspendKeyword::postvisit(
692                const ast::FunctionDecl * decl ) {
693        // Only modify a full definition of a generator with states.
694        if ( !decl->stmts || !in_generator || labels.empty() ) return decl;
695
696        const CodeLocation & location = decl->location;
697
698        // Create a new function body:
699        // static void * __generator_labels[] = {&&s0, &&s1, ...};
700        // void * __generator_label = __generator_labels[GEN.__generator_state];
701        // goto * __generator_label;
702        // s0: ;
703        // OLD_BODY
704
705        // This is the null statement inserted right before the body.
706        ast::NullStmt * noop = new ast::NullStmt( location );
707        noop->labels.push_back( ControlStruct::newLabel( "generator", noop ) );
708        const ast::Label & first_label = noop->labels.back();
709
710        // Add each label to the init, starting with the first label.
711        std::vector<ast::ptr<ast::Init>> inits = {
712                new ast::SingleInit( location,
713                        new ast::LabelAddressExpr( location, copy( first_label ) ) ) };
714        // Then go through all the stored labels, and clear the store.
715        for ( auto && label : labels ) {
716                inits.push_back( new ast::SingleInit( label.location,
717                        new ast::LabelAddressExpr( label.location, std::move( label )
718                        ) ) );
719        }
720        labels.clear();
721        // Then construct the initializer itself.
722        auto init = new ast::ListInit( location, std::move( inits ) );
723
724        ast::ObjectDecl * generatorLabels = new ast::ObjectDecl(
725                location,
726                "__generator_labels",
727                new ast::ArrayType(
728                        new ast::PointerType( new ast::VoidType() ),
729                        nullptr,
730                        ast::FixedLen,
731                        ast::DynamicDim
732                ),
733                init,
734                ast::Storage::Classes(),
735                ast::Linkage::AutoGen
736        );
737
738        ast::ObjectDecl * generatorLabel = new ast::ObjectDecl(
739                location,
740                "__generator_label",
741                new ast::PointerType( new ast::VoidType() ),
742                new ast::SingleInit( location,
743                        new ast::UntypedExpr( location,
744                                new ast::NameExpr( location, "?[?]" ),
745                                {
746                                        // TODO: Could be a variable expr.
747                                        new ast::NameExpr( location, "__generator_labels" ),
748                                        new ast::UntypedMemberExpr( location,
749                                                new ast::NameExpr( location, "__generator_state" ),
750                                                new ast::VariableExpr( location, in_generator )
751                                        )
752                                }
753                        )
754                ),
755                ast::Storage::Classes(),
756                ast::Linkage::AutoGen
757        );
758
759        ast::BranchStmt * theGoTo = new ast::BranchStmt(
760                location, new ast::VariableExpr( location, generatorLabel )
761        );
762
763        // The noop goes here in order.
764
765        ast::CompoundStmt * body = new ast::CompoundStmt( location, {
766                { new ast::DeclStmt( location, generatorLabels ) },
767                { new ast::DeclStmt( location, generatorLabel ) },
768                { theGoTo },
769                { noop },
770                { decl->stmts },
771        } );
772
773        auto mutDecl = ast::mutate( decl );
774        mutDecl->stmts = body;
775        return mutDecl;
776}
777
778const ast::Stmt * SuspendKeyword::postvisit( const ast::SuspendStmt * stmt ) {
779        switch ( stmt->kind ) {
780        case ast::SuspendStmt::None:
781                // Use the context to determain the implicit target.
782                if ( in_generator ) {
783                        return make_generator_suspend( stmt );
784                } else {
785                        return make_coroutine_suspend( stmt );
786                }
787        case ast::SuspendStmt::Coroutine:
788                return make_coroutine_suspend( stmt );
789        case ast::SuspendStmt::Generator:
790                // Generator suspends must be directly in a generator.
791                if ( !in_generator ) SemanticError( stmt->location, "\"suspend generator\" must be used inside main of generator type." );
792                return make_generator_suspend( stmt );
793        }
794        assert( false );
795        return stmt;
796}
797
798/// Find the real/official suspend declaration.
799bool SuspendKeyword::is_real_suspend( const ast::FunctionDecl * decl ) {
800        return ( !decl->linkage.is_mangled
801                && 0 == decl->params.size()
802                && 0 == decl->returns.size()
803                && "__cfactx_suspend" == decl->name );
804}
805
806const ast::Stmt * SuspendKeyword::make_generator_suspend(
807                const ast::SuspendStmt * stmt ) {
808        assert( in_generator );
809        // Target code is:
810        //   GEN.__generator_state = X;
811        //   THEN
812        //   return;
813        //   __gen_X:;
814
815        const CodeLocation & location = stmt->location;
816
817        LabelPair label = make_label( stmt );
818
819        // This is the context saving statement.
820        stmtsToAddBefore.push_back( new ast::ExprStmt( location,
821                new ast::UntypedExpr( location,
822                        new ast::NameExpr( location, "?=?" ),
823                        {
824                                new ast::UntypedMemberExpr( location,
825                                        new ast::NameExpr( location, "__generator_state" ),
826                                        new ast::VariableExpr( location, in_generator )
827                                ),
828                                ast::ConstantExpr::from_int( location, label.idx ),
829                        }
830                )
831        ) );
832
833        // The THEN component is conditional (return is not).
834        if ( stmt->then ) {
835                stmtsToAddBefore.push_back( stmt->then.get() );
836        }
837        stmtsToAddBefore.push_back( new ast::ReturnStmt( location, nullptr ) );
838
839        // The null statement replaces the old suspend statement.
840        return new ast::NullStmt( location, { label.obj } );
841}
842
843const ast::Stmt * SuspendKeyword::make_coroutine_suspend(
844                const ast::SuspendStmt * stmt ) {
845        // The only thing we need from the old statement is the location.
846        const CodeLocation & location = stmt->location;
847
848        if ( !decl_suspend ) {
849                SemanticError( location, "suspend keyword applied to coroutines requires coroutines to be in scope, add #include <coroutine.hfa>." );
850        }
851        if ( stmt->then ) {
852                SemanticError( location, "Compound statement following coroutines is not implemented." );
853        }
854
855        return new ast::ExprStmt( location,
856                new ast::UntypedExpr( location,
857                        ast::VariableExpr::functionPointer( location, decl_suspend ) )
858        );
859}
860
861// --------------------------------------------------------------------------
862struct MutexKeyword final : public ast::WithDeclsToAdd<> {
863        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
864        void postvisit( const ast::StructDecl * decl );
865        const ast::Stmt * postvisit( const ast::MutexStmt * stmt );
866
867        static std::vector<const ast::DeclWithType *> findMutexArgs(
868                        const ast::FunctionDecl * decl, bool & first );
869        static void validate( const ast::DeclWithType * decl );
870
871        ast::CompoundStmt * addDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
872        ast::CompoundStmt * addStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
873        ast::CompoundStmt * addStatements( const ast::CompoundStmt * body, const std::vector<ast::ptr<ast::Expr>> & args );
874        ast::CompoundStmt * addThreadDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt * body, const std::vector<const ast::DeclWithType *> & args );
875        ast::ExprStmt * genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param);
876        ast::IfStmt * genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam );
877private:
878        const ast::StructDecl * monitor_decl = nullptr;
879        const ast::StructDecl * guard_decl = nullptr;
880        const ast::StructDecl * dtor_guard_decl = nullptr;
881        const ast::StructDecl * thread_guard_decl = nullptr;
882        const ast::StructDecl * lock_guard_decl = nullptr;
883
884        static ast::ptr<ast::Type> generic_func;
885
886        UniqueName mutex_func_namer = UniqueName("__lock_unlock_curr");
887};
888
889const ast::FunctionDecl * MutexKeyword::postvisit(
890                const ast::FunctionDecl * decl ) {
891        bool is_first_argument_mutex = false;
892        const std::vector<const ast::DeclWithType *> mutexArgs =
893                findMutexArgs( decl, is_first_argument_mutex );
894        bool const isDtor = CodeGen::isDestructor( decl->name );
895
896        // Does this function have any mutex arguments that connect to monitors?
897        if ( mutexArgs.empty() ) {
898                // If this is the destructor for a monitor it must be mutex.
899                if ( isDtor ) {
900                        // This reflects MutexKeyword::validate, but no error messages.
901                        const ast::Type * type = decl->type->params.front();
902
903                        // If it's a copy, it's not a mutex.
904                        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
905                        if ( nullptr == refType ) {
906                                return decl;
907                        }
908
909                        // If it is not pointing directly to a type, it's not a mutex.
910                        auto base = refType->base;
911                        if ( base.as<ast::ReferenceType>() ) return decl;
912                        if ( base.as<ast::PointerType>() ) return decl;
913
914                        // If it is not a struct, it's not a mutex.
915                        auto baseStruct = base.as<ast::StructInstType>();
916                        if ( nullptr == baseStruct ) return decl;
917
918                        // If it is a monitor, then it is a monitor.
919                        if( baseStruct->base->is_monitor() || baseStruct->base->is_thread() ) {
920                                SemanticError( decl, "destructors for structures declared as \"monitor\" must use mutex parameters " );
921                        }
922                }
923                return decl;
924        }
925
926        // Monitors can't be constructed with mutual exclusion.
927        if ( CodeGen::isConstructor( decl->name ) && is_first_argument_mutex ) {
928                SemanticError( decl, "constructors cannot have mutex parameters " );
929        }
930
931        // It makes no sense to have multiple mutex parameters for the destructor.
932        if ( isDtor && mutexArgs.size() != 1 ) {
933                SemanticError( decl, "destructors can only have 1 mutex argument " );
934        }
935
936        // Make sure all the mutex arguments are monitors.
937        for ( auto arg : mutexArgs ) {
938                validate( arg );
939        }
940
941        // Check to see if the body needs to be instrument the body.
942        const ast::CompoundStmt * body = decl->stmts;
943        if ( !body ) return decl;
944
945        // Check to if the required headers have been seen.
946        if ( !monitor_decl || !guard_decl || !dtor_guard_decl ) {
947                SemanticError( decl, "mutex keyword requires monitors to be in scope, add #include <monitor.hfa>." );
948        }
949
950        // Instrument the body.
951        ast::CompoundStmt * newBody = nullptr;
952        if ( isDtor && isThread( mutexArgs.front() ) ) {
953                if ( !thread_guard_decl ) {
954                        SemanticError( decl, "thread destructor requires threads to be in scope, add #include <thread.hfa>." );
955                }
956                newBody = addThreadDtorStatements( decl, body, mutexArgs );
957        } else if ( isDtor ) {
958                newBody = addDtorStatements( decl, body, mutexArgs );
959        } else {
960                newBody = addStatements( decl, body, mutexArgs );
961        }
962        assert( newBody );
963        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, newBody );
964}
965
966void MutexKeyword::postvisit( const ast::StructDecl * decl ) {
967        if ( !decl->body ) {
968                return;
969        } else if ( decl->name == "monitor$" ) {
970                assert( !monitor_decl );
971                monitor_decl = decl;
972        } else if ( decl->name == "monitor_guard_t" ) {
973                assert( !guard_decl );
974                guard_decl = decl;
975        } else if ( decl->name == "monitor_dtor_guard_t" ) {
976                assert( !dtor_guard_decl );
977                dtor_guard_decl = decl;
978        } else if ( decl->name == "thread_dtor_guard_t" ) {
979                assert( !thread_guard_decl );
980                thread_guard_decl = decl;
981        } else if ( decl->name == "__mutex_stmt_lock_guard" ) {
982                assert( !lock_guard_decl );
983                lock_guard_decl = decl;
984        }
985}
986
987const ast::Stmt * MutexKeyword::postvisit( const ast::MutexStmt * stmt ) {
988        if ( !lock_guard_decl ) {
989                SemanticError( stmt->location, "mutex stmt requires a header, add #include <mutex_stmt.hfa>." );
990        }
991        ast::CompoundStmt * body =
992                        new ast::CompoundStmt( stmt->location, { stmt->stmt } );
993
994        return addStatements( body, stmt->mutexObjs );;
995}
996
997std::vector<const ast::DeclWithType *> MutexKeyword::findMutexArgs(
998                const ast::FunctionDecl * decl, bool & first ) {
999        std::vector<const ast::DeclWithType *> mutexArgs;
1000
1001        bool once = true;
1002        for ( auto arg : decl->params ) {
1003                const ast::Type * type = arg->get_type();
1004                if ( type->is_mutex() ) {
1005                        if ( once ) first = true;
1006                        mutexArgs.push_back( arg.get() );
1007                }
1008                once = false;
1009        }
1010        return mutexArgs;
1011}
1012
1013void MutexKeyword::validate( const ast::DeclWithType * decl ) {
1014        const ast::Type * type = decl->get_type();
1015
1016        // If it's a copy, it's not a mutex.
1017        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
1018        if ( nullptr == refType ) {
1019                SemanticError( decl, "Mutex argument must be of reference type " );
1020        }
1021
1022        // If it is not pointing directly to a type, it's not a mutex.
1023        auto base = refType->base;
1024        if ( base.as<ast::ReferenceType>() || base.as<ast::PointerType>() ) {
1025                SemanticError( decl, "Mutex argument have exactly one level of indirection " );
1026        }
1027
1028        // If it is not a struct, it's not a mutex.
1029        auto baseStruct = base.as<ast::StructInstType>();
1030        if ( nullptr == baseStruct ) return;
1031
1032        // Make sure that only the outer reference is mutex.
1033        if( baseStruct->is_mutex() ) {
1034                SemanticError( decl, "mutex keyword may only appear once per argument " );
1035        }
1036}
1037
1038ast::CompoundStmt * MutexKeyword::addDtorStatements(
1039                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1040                const std::vector<const ast::DeclWithType *> & args ) {
1041        ast::Type * argType = ast::shallowCopy( args.front()->get_type() );
1042        argType->set_mutex( false );
1043
1044        ast::CompoundStmt * mutBody = ast::mutate( body );
1045
1046        // Generated code goes near the beginning of body:
1047        const CodeLocation & location = mutBody->location;
1048
1049        const ast::ObjectDecl * monitor = new ast::ObjectDecl(
1050                location,
1051                "__monitor",
1052                new ast::PointerType( new ast::StructInstType( monitor_decl ) ),
1053                new ast::SingleInit(
1054                        location,
1055                        new ast::UntypedExpr(
1056                                location,
1057                                new ast::NameExpr( location, "get_monitor" ),
1058                                { new ast::CastExpr(
1059                                        location,
1060                                        new ast::VariableExpr( location, args.front() ),
1061                                        argType, ast::ExplicitCast
1062                                ) }
1063                        )
1064                )
1065        );
1066
1067        assert( generic_func );
1068
1069        // In reverse order:
1070        // monitor_dtor_guard_t __guard = { __monitor, func, false };
1071        mutBody->push_front(
1072                new ast::DeclStmt( location, new ast::ObjectDecl(
1073                        location,
1074                        "__guard",
1075                        new ast::StructInstType( dtor_guard_decl ),
1076                        new ast::ListInit(
1077                                location,
1078                                {
1079                                        new ast::SingleInit( location,
1080                                                new ast::AddressExpr( location,
1081                                                        new ast::VariableExpr( location, monitor ) ) ),
1082                                        new ast::SingleInit( location,
1083                                                new ast::CastExpr( location,
1084                                                        new ast::VariableExpr( location, func ),
1085                                                        generic_func,
1086                                                        ast::ExplicitCast ) ),
1087                                        new ast::SingleInit( location,
1088                                                ast::ConstantExpr::from_bool( location, false ) ),
1089                                },
1090                                {},
1091                                ast::MaybeConstruct
1092                        )
1093                ))
1094        );
1095
1096        // monitor$ * __monitor = get_monitor(a);
1097        mutBody->push_front( new ast::DeclStmt( location, monitor ) );
1098
1099        return mutBody;
1100}
1101
1102ast::CompoundStmt * MutexKeyword::addStatements(
1103                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1104                const std::vector<const ast::DeclWithType * > & args ) {
1105        ast::CompoundStmt * mutBody = ast::mutate( body );
1106
1107        // Code is generated near the beginning of the compound statement.
1108        const CodeLocation & location = mutBody->location;
1109
1110        // Make pointer to the monitors.
1111        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1112                location,
1113                "__monitors",
1114                new ast::ArrayType(
1115                        new ast::PointerType(
1116                                new ast::StructInstType( monitor_decl )
1117                        ),
1118                        ast::ConstantExpr::from_ulong( location, args.size() ),
1119                        ast::FixedLen,
1120                        ast::DynamicDim
1121                ),
1122                new ast::ListInit(
1123                        location,
1124                        map_range<std::vector<ast::ptr<ast::Init>>>(
1125                                args,
1126                                []( const ast::DeclWithType * decl ) {
1127                                        return new ast::SingleInit(
1128                                                decl->location,
1129                                                new ast::UntypedExpr(
1130                                                        decl->location,
1131                                                        new ast::NameExpr( decl->location, "get_monitor" ),
1132                                                        {
1133                                                                new ast::CastExpr(
1134                                                                        decl->location,
1135                                                                        new ast::VariableExpr( decl->location, decl ),
1136                                                                        decl->get_type(),
1137                                                                        ast::ExplicitCast
1138                                                                )
1139                                                        }
1140                                                )
1141                                        );
1142                                }
1143                        )
1144                )
1145        );
1146
1147        assert( generic_func );
1148
1149        // In Reverse Order:
1150        mutBody->push_front(
1151                new ast::DeclStmt( location, new ast::ObjectDecl(
1152                        location,
1153                        "__guard",
1154                        new ast::StructInstType( guard_decl ),
1155                        new ast::ListInit(
1156                                location,
1157                                {
1158                                        new ast::SingleInit( location,
1159                                                new ast::VariableExpr( location, monitors ) ),
1160                                        new ast::SingleInit( location,
1161                                                ast::ConstantExpr::from_ulong( location, args.size() ) ),
1162                                        new ast::SingleInit( location, new ast::CastExpr(
1163                                                location,
1164                                                new ast::VariableExpr( location, func ),
1165                                                generic_func,
1166                                                ast::ExplicitCast
1167                                        ) ),
1168                                },
1169                                {},
1170                                ast::MaybeConstruct
1171                        )
1172                ))
1173        );
1174
1175        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1176        mutBody->push_front( new ast::DeclStmt( location, monitors ) );
1177
1178        return mutBody;
1179}
1180
1181// generates a cast to the void ptr to the appropriate lock type and dereferences it before calling lock or unlock on it
1182// used to undo the type erasure done by storing all the lock pointers as void
1183ast::ExprStmt * MutexKeyword::genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param ) {
1184        return new ast::ExprStmt( location,
1185                new ast::UntypedExpr( location,
1186                        new ast::NameExpr( location, fnName ), {
1187                                ast::UntypedExpr::createDeref(
1188                                        location,
1189                                        new ast::CastExpr( location,
1190                                                param,
1191                                                new ast::PointerType( new ast::TypeofType( new ast::UntypedExpr(
1192                                                        expr->location,
1193                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_type" ),
1194                                                        { expr }
1195                                                ) ) ),
1196                                                ast::GeneratedFlag::ExplicitCast
1197                                        )
1198                                )
1199                        }
1200                )
1201        );
1202}
1203
1204ast::IfStmt * MutexKeyword::genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam ) {
1205        ast::IfStmt * outerLockIf = nullptr;
1206        ast::IfStmt * lastLockIf = nullptr;
1207
1208        //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1209        for ( long unsigned int i = 0; i < args.size(); i++ ) {
1210
1211                ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1212                        new ast::NameExpr( location, "?==?" ), {
1213                                ast::deepCopy( thisParam ),
1214                                new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1215                        }
1216                );
1217
1218                ast::IfStmt * currLockIf = new ast::IfStmt(
1219                        location,
1220                        ifCond,
1221                        genVirtLockUnlockExpr( fnName, args.at(i), location, ast::deepCopy( thisParam ) )
1222                );
1223
1224                if ( i == 0 ) {
1225                        outerLockIf = currLockIf;
1226                } else {
1227                        // add ifstmt to else of previous stmt
1228                        lastLockIf->else_ = currLockIf;
1229                }
1230
1231                lastLockIf = currLockIf;
1232        }
1233        return outerLockIf;
1234}
1235
1236void flattenTuple( const ast::UntypedTupleExpr * tuple, std::vector<ast::ptr<ast::Expr>> & output ) {
1237        for ( auto & expr : tuple->exprs ) {
1238                const ast::UntypedTupleExpr * innerTuple = dynamic_cast<const ast::UntypedTupleExpr *>(expr.get());
1239                if ( innerTuple ) flattenTuple( innerTuple, output );
1240                else output.emplace_back( ast::deepCopy( expr ));
1241        }
1242}
1243
1244ast::CompoundStmt * MutexKeyword::addStatements(
1245                const ast::CompoundStmt * body,
1246                const std::vector<ast::ptr<ast::Expr>> & args ) {
1247
1248        // Code is generated near the beginning of the compound statement.
1249        const CodeLocation & location = body->location;
1250
1251                // final body to return
1252        ast::CompoundStmt * newBody = new ast::CompoundStmt( location );
1253
1254        // std::string lockFnName = mutex_func_namer.newName();
1255        // std::string unlockFnName = mutex_func_namer.newName();
1256
1257        // If any arguments to the mutex stmt are tuples, flatten them
1258        std::vector<ast::ptr<ast::Expr>> flattenedArgs;
1259        for ( auto & arg : args ) {
1260                const ast::UntypedTupleExpr * tuple = dynamic_cast<const ast::UntypedTupleExpr *>(args.at(0).get());
1261                if ( tuple ) flattenTuple( tuple, flattenedArgs );
1262                else flattenedArgs.emplace_back( ast::deepCopy( arg ));
1263        }
1264
1265        // Make pointer to the monitors.
1266        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1267                location,
1268                "__monitors",
1269                new ast::ArrayType(
1270                        new ast::PointerType(
1271                                new ast::VoidType()
1272                        ),
1273                        ast::ConstantExpr::from_ulong( location, flattenedArgs.size() ),
1274                        ast::FixedLen,
1275                        ast::DynamicDim
1276                ),
1277                new ast::ListInit(
1278                        location,
1279                        map_range<std::vector<ast::ptr<ast::Init>>>(
1280                                flattenedArgs, [](const ast::Expr * expr) {
1281                                        return new ast::SingleInit(
1282                                                expr->location,
1283                                                new ast::UntypedExpr(
1284                                                        expr->location,
1285                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_ptr" ),
1286                                                        { expr }
1287                                                )
1288                                        );
1289                                }
1290                        )
1291                )
1292        );
1293
1294        ast::StructInstType * lock_guard_struct =
1295                        new ast::StructInstType( lock_guard_decl );
1296
1297        // use try stmts to lock and finally to unlock
1298        ast::TryStmt * outerTry = nullptr;
1299        ast::TryStmt * currentTry;
1300        ast::CompoundStmt * lastBody = nullptr;
1301
1302        // adds a nested try stmt for each lock we are locking
1303        for ( long unsigned int i = 0; i < flattenedArgs.size(); i++ ) {
1304                ast::UntypedExpr * innerAccess = new ast::UntypedExpr(
1305                        location,
1306                        new ast::NameExpr( location,"?[?]" ), {
1307                                new ast::NameExpr( location, "__monitors" ),
1308                                ast::ConstantExpr::from_int( location, i )
1309                        }
1310                );
1311
1312                // make the try body
1313                ast::CompoundStmt * currTryBody = new ast::CompoundStmt( location );
1314                ast::IfStmt * lockCall = genTypeDiscrimLockUnlock( "lock", flattenedArgs, location, innerAccess );
1315                currTryBody->push_back( lockCall );
1316
1317                // make the finally stmt
1318                ast::CompoundStmt * currFinallyBody = new ast::CompoundStmt( location );
1319                ast::IfStmt * unlockCall = genTypeDiscrimLockUnlock( "unlock", flattenedArgs, location, innerAccess );
1320                currFinallyBody->push_back( unlockCall );
1321
1322                // construct the current try
1323                currentTry = new ast::TryStmt(
1324                        location,
1325                        currTryBody,
1326                        {},
1327                        new ast::FinallyClause( location, currFinallyBody )
1328                );
1329                if ( i == 0 ) outerTry = currentTry;
1330                else {
1331                        // pushback try into the body of the outer try
1332                        lastBody->push_back( currentTry );
1333                }
1334                lastBody = currTryBody;
1335        }
1336
1337        // push body into innermost try body
1338        if ( lastBody != nullptr ) {
1339                lastBody->push_back( body );
1340                newBody->push_front( outerTry );
1341        }
1342
1343        // monitor_guard_t __guard = { __monitors, # };
1344        newBody->push_front(
1345                new ast::DeclStmt(
1346                        location,
1347                        new ast::ObjectDecl(
1348                                location,
1349                                "__guard",
1350                                lock_guard_struct,
1351                                new ast::ListInit(
1352                                        location,
1353                                        {
1354                                                new ast::SingleInit(
1355                                                        location,
1356                                                        new ast::VariableExpr( location, monitors ) ),
1357                                                new ast::SingleInit(
1358                                                        location,
1359                                                        ast::ConstantExpr::from_ulong( location, flattenedArgs.size() ) ),
1360                                        },
1361                                        {},
1362                                        ast::MaybeConstruct
1363                                )
1364                        )
1365                )
1366        );
1367
1368        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1369        newBody->push_front( new ast::DeclStmt( location, monitors ) );
1370
1371        // // The parameter for both __lock_curr/__unlock_curr routines.
1372        // ast::ObjectDecl * this_decl = new ast::ObjectDecl(
1373        //      location,
1374        //      "this",
1375        //      new ast::PointerType( new ast::VoidType() ),
1376        //      nullptr,
1377        //      {},
1378        //      ast::Linkage::Cforall
1379        // );
1380
1381        // ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
1382        //      location,
1383        //      lockFnName,
1384        //      { /* forall */ },
1385        //      {
1386        //              // Copy the declaration of this.
1387        //              this_decl,
1388        //      },
1389        //      { /* returns */ },
1390        //      nullptr,
1391        //      0,
1392        //      ast::Linkage::Cforall,
1393        //      { /* attributes */ },
1394        //      ast::Function::Inline
1395        // );
1396
1397        // ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
1398        //      location,
1399        //      unlockFnName,
1400        //      { /* forall */ },
1401        //      {
1402        //              // Copy the declaration of this.
1403        //              ast::deepCopy( this_decl ),
1404        //      },
1405        //      { /* returns */ },
1406        //      nullptr,
1407        //      0,
1408        //      ast::Linkage::Cforall,
1409        //      { /* attributes */ },
1410        //      ast::Function::Inline
1411        // );
1412
1413        // ast::IfStmt * outerLockIf = nullptr;
1414        // ast::IfStmt * outerUnlockIf = nullptr;
1415        // ast::IfStmt * lastLockIf = nullptr;
1416        // ast::IfStmt * lastUnlockIf = nullptr;
1417
1418        // //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1419        // for ( long unsigned int i = 0; i < args.size(); i++ ) {
1420        //      ast::VariableExpr * thisParam = new ast::VariableExpr( location, InitTweak::getParamThis( lock_decl ) );
1421        //      ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1422        //              new ast::NameExpr( location, "?==?" ), {
1423        //                      thisParam,
1424        //                      new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1425        //              }
1426        //      );
1427
1428        //      ast::IfStmt * currLockIf = new ast::IfStmt(
1429        //              location,
1430        //              ast::deepCopy( ifCond ),
1431        //              genVirtLockUnlockExpr( "lock", args.at(i), location, ast::deepCopy( thisParam ) )
1432        //      );
1433
1434        //      ast::IfStmt * currUnlockIf = new ast::IfStmt(
1435        //              location,
1436        //              ifCond,
1437        //              genVirtLockUnlockExpr( "unlock", args.at(i), location, ast::deepCopy( thisParam ) )
1438        //      );
1439
1440        //      if ( i == 0 ) {
1441        //              outerLockIf = currLockIf;
1442        //              outerUnlockIf = currUnlockIf;
1443        //      } else {
1444        //              // add ifstmt to else of previous stmt
1445        //              lastLockIf->else_ = currLockIf;
1446        //              lastUnlockIf->else_ = currUnlockIf;
1447        //      }
1448
1449        //      lastLockIf = currLockIf;
1450        //      lastUnlockIf = currUnlockIf;
1451        // }
1452
1453        // // add pointer typing if/elifs to body of routines
1454        // lock_decl->stmts = new ast::CompoundStmt( location, { outerLockIf } );
1455        // unlock_decl->stmts = new ast::CompoundStmt( location, { outerUnlockIf } );
1456
1457        // // add routines to scope
1458        // declsToAddBefore.push_back( lock_decl );
1459        // declsToAddBefore.push_back( unlock_decl );
1460
1461        // newBody->push_front(new ast::DeclStmt( location, lock_decl ));
1462        // newBody->push_front(new ast::DeclStmt( location, unlock_decl ));
1463
1464        return newBody;
1465}
1466
1467ast::CompoundStmt * MutexKeyword::addThreadDtorStatements(
1468                const ast::FunctionDecl*, const ast::CompoundStmt * body,
1469                const std::vector<const ast::DeclWithType * > & args ) {
1470        assert( args.size() == 1 );
1471        const ast::DeclWithType * arg = args.front();
1472        const ast::Type * argType = arg->get_type();
1473        assert( argType->is_mutex() );
1474
1475        ast::CompoundStmt * mutBody = ast::mutate( body );
1476
1477        // The code is generated near the front of the body.
1478        const CodeLocation & location = mutBody->location;
1479
1480        // thread_dtor_guard_t __guard = { this, intptr( 0 ) };
1481        mutBody->push_front( new ast::DeclStmt(
1482                location,
1483                new ast::ObjectDecl(
1484                        location,
1485                        "__guard",
1486                        new ast::StructInstType( thread_guard_decl ),
1487                        new ast::ListInit(
1488                                location,
1489                                {
1490                                        new ast::SingleInit( location,
1491                                                new ast::CastExpr( location,
1492                                                        new ast::VariableExpr( location, arg ), argType ) ),
1493                                        new ast::SingleInit(
1494                                                location,
1495                                                new ast::UntypedExpr(
1496                                                        location,
1497                                                        new ast::NameExpr( location, "intptr" ), {
1498                                                                ast::ConstantExpr::from_int( location, 0 ),
1499                                                        }
1500                                                ) ),
1501                                },
1502                                {},
1503                                ast::MaybeConstruct
1504                        )
1505                )
1506        ));
1507
1508        return mutBody;
1509}
1510
1511ast::ptr<ast::Type> MutexKeyword::generic_func =
1512        new ast::FunctionType( ast::VariableArgs );
1513
1514// --------------------------------------------------------------------------
1515struct ThreadStarter final {
1516        void previsit( const ast::StructDecl * decl );
1517        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
1518
1519private:
1520        bool thread_ctor_seen = false;
1521        const ast::StructDecl * thread_decl = nullptr;
1522};
1523
1524void ThreadStarter::previsit( const ast::StructDecl * decl ) {
1525        if ( decl->body && decl->name == "thread$" ) {
1526                assert( !thread_decl );
1527                thread_decl = decl;
1528        }
1529}
1530
1531const ast::FunctionDecl * ThreadStarter::postvisit( const ast::FunctionDecl * decl ) {
1532        if ( !CodeGen::isConstructor( decl->name ) ) return decl;
1533
1534        // Seach for the thread constructor.
1535        // (Are the "prefixes" of these to blocks the same?)
1536        const ast::Type * typeof_this = InitTweak::getTypeofThis( decl->type );
1537        auto ctored_type = dynamic_cast<const ast::StructInstType *>( typeof_this );
1538        if ( ctored_type && ctored_type->base == thread_decl ) {
1539                thread_ctor_seen = true;
1540        }
1541
1542        // Modify this declaration, the extra checks to see if we will are first.
1543        const ast::ptr<ast::DeclWithType> & param = decl->params.front();
1544        auto type = dynamic_cast<const ast::StructInstType *>(
1545                ast::getPointerBase( param->get_type() ) );
1546        if ( nullptr == type ) return decl;
1547        if ( !type->base->is_thread() ) return decl;
1548        if ( !thread_decl || !thread_ctor_seen ) {
1549                SemanticError( type->base->location, "thread keyword requires threads to be in scope, add #include <thread.hfa>." );
1550        }
1551        const ast::CompoundStmt * stmt = decl->stmts;
1552        if ( nullptr == stmt ) return decl;
1553
1554        // Now do the actual modification:
1555        ast::CompoundStmt * mutStmt = ast::mutate( stmt );
1556        const CodeLocation & location = mutStmt->location;
1557        mutStmt->push_back(
1558                new ast::ExprStmt(
1559                        location,
1560                        new ast::UntypedExpr(
1561                                location,
1562                                new ast::NameExpr( location, "__thrd_start" ),
1563                                {
1564                                        new ast::VariableExpr( location, param ),
1565                                        new ast::NameExpr( location, "main" ),
1566                                }
1567                        )
1568                )
1569        );
1570
1571        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, mutStmt );
1572}
1573
1574} // namespace
1575
1576// --------------------------------------------------------------------------
1577// Interface Functions:
1578
1579void implementKeywords( ast::TranslationUnit & translationUnit ) {
1580        ast::Pass<ThreadKeyword>::run( translationUnit );
1581        ast::Pass<CoroutineKeyword>::run( translationUnit );
1582        ast::Pass<MonitorKeyword>::run( translationUnit );
1583        ast::Pass<GeneratorKeyword>::run( translationUnit );
1584        ast::Pass<SuspendKeyword>::run( translationUnit );
1585}
1586
1587void implementMutex( ast::TranslationUnit & translationUnit ) {
1588        ast::Pass<MutexKeyword>::run( translationUnit );
1589}
1590
1591void implementThreadStarter( ast::TranslationUnit & translationUnit ) {
1592        ast::Pass<ThreadStarter>::run( translationUnit );
1593}
1594
1595}
1596
1597// Local Variables: //
1598// tab-width: 4 //
1599// mode: c++ //
1600// compile-command: "make install" //
1601// End: //
Note: See TracBrowser for help on using the repository browser.