source: src/Concurrency/KeywordsNew.cpp @ 4269d1b

Last change on this file since 4269d1b was b2ecd48, checked in by Andrew Beach <ajbeach@…>, 16 months ago

Changes related to invariant checking scoping, it is not ready by these are unlikely to change.

  • Property mode set to 100644
File size: 48.9 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// KeywordsNew.cpp -- Implement concurrency constructs from their keywords.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Nov 16  9:53:00 2021
11// Last Modified By : Andrew Beach
12// Last Modified On : Fri Mar 11 10:40:00 2022
13// Update Count     : 2
14//
15
16#include <iostream>
17
18#include "Concurrency/Keywords.h"
19
20#include "AST/Copy.hpp"
21#include "AST/Decl.hpp"
22#include "AST/Expr.hpp"
23#include "AST/Inspect.hpp"
24#include "AST/Pass.hpp"
25#include "AST/Stmt.hpp"
26#include "AST/DeclReplacer.hpp"
27#include "AST/TranslationUnit.hpp"
28#include "CodeGen/OperatorTable.h"
29#include "Common/Examine.h"
30#include "Common/utility.h"
31#include "Common/UniqueName.h"
32#include "ControlStruct/LabelGeneratorNew.hpp"
33#include "InitTweak/InitTweak.h"
34#include "Virtual/Tables.h"
35
36namespace Concurrency {
37
38namespace {
39
40// --------------------------------------------------------------------------
41// Loose Helper Functions:
42
43/// Detect threads constructed with the keyword thread.
44bool isThread( const ast::DeclWithType * decl ) {
45        auto baseType = decl->get_type()->stripDeclarator();
46        auto instType = dynamic_cast<const ast::StructInstType *>( baseType );
47        if ( nullptr == instType ) { return false; }
48        return instType->base->is_thread();
49}
50
51/// Get the virtual type id if given a type name.
52std::string typeIdType( std::string const & exception_name ) {
53        return exception_name.empty() ? std::string()
54                : Virtual::typeIdType( exception_name );
55}
56
57/// Get the vtable type name if given a type name.
58std::string vtableTypeName( std::string const & exception_name ) {
59        return exception_name.empty() ? std::string()
60                : Virtual::vtableTypeName( exception_name );
61}
62
63static ast::Type * mutate_under_references( ast::ptr<ast::Type>& type ) {
64        ast::Type * mutType = type.get_and_mutate();
65        for ( ast::ReferenceType * mutRef
66                ; (mutRef = dynamic_cast<ast::ReferenceType *>( mutType ))
67                ; mutType = mutRef->base.get_and_mutate() );
68        return mutType;
69}
70
71// Describe that it adds the generic parameters and the uses of the generic
72// parameters on the function and first "this" argument.
73ast::FunctionDecl * fixupGenerics(
74                const ast::FunctionDecl * func, const ast::StructDecl * decl ) {
75        const CodeLocation & location = decl->location;
76        // We have to update both the declaration
77        auto mutFunc = ast::mutate( func );
78        auto mutType = mutFunc->type.get_and_mutate();
79
80        if ( decl->params.empty() ) {
81                return mutFunc;
82        }
83
84        assert( 0 != mutFunc->params.size() );
85        assert( 0 != mutType->params.size() );
86
87        // Add the "forall" clause information.
88        for ( const ast::ptr<ast::TypeDecl> & typeParam : decl->params ) {
89                auto typeDecl = ast::deepCopy( typeParam );
90                mutFunc->type_params.push_back( typeDecl );
91                mutType->forall.push_back( new ast::TypeInstType( typeDecl ) );
92                for ( auto & assertion : typeDecl->assertions ) {
93                        mutFunc->assertions.push_back( assertion );
94                        mutType->assertions.emplace_back(
95                                new ast::VariableExpr( location, assertion ) );
96                }
97                typeDecl->assertions.clear();
98        }
99
100        // Even chain_mutate is not powerful enough for this:
101        ast::ptr<ast::Type>& paramType = strict_dynamic_cast<ast::ObjectDecl *>(
102                mutFunc->params[0].get_and_mutate() )->type;
103        auto paramTypeInst = strict_dynamic_cast<ast::StructInstType *>(
104                mutate_under_references( paramType ) );
105        auto typeParamInst = strict_dynamic_cast<ast::StructInstType *>(
106                mutate_under_references( mutType->params[0] ) );
107
108        for ( const ast::ptr<ast::TypeDecl> & typeDecl : mutFunc->type_params ) {
109                paramTypeInst->params.push_back(
110                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
111                typeParamInst->params.push_back(
112                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
113        }
114
115        return mutFunc;
116}
117
118// --------------------------------------------------------------------------
119struct ConcurrentSueKeyword : public ast::WithDeclsToAdd<> {
120        ConcurrentSueKeyword(
121                std::string&& type_name, std::string&& field_name,
122                std::string&& getter_name, std::string&& context_error,
123                std::string&& exception_name,
124                bool needs_main, ast::AggregateDecl::Aggregate cast_target
125        ) :
126                type_name( type_name ), field_name( field_name ),
127                getter_name( getter_name ), context_error( context_error ),
128                exception_name( exception_name ),
129                typeid_name( typeIdType( exception_name ) ),
130                vtable_name( vtableTypeName( exception_name ) ),
131                needs_main( needs_main ), cast_target( cast_target )
132        {}
133
134        virtual ~ConcurrentSueKeyword() {}
135
136        const ast::Decl * postvisit( const ast::StructDecl * decl );
137        const ast::DeclWithType * postvisit( const ast::FunctionDecl * decl );
138        const ast::Expr * postvisit( const ast::KeywordCastExpr * expr );
139
140        struct StructAndField {
141                const ast::StructDecl * decl;
142                const ast::ObjectDecl * field;
143        };
144
145        const ast::StructDecl * handleStruct( const ast::StructDecl * );
146        void handleMain( const ast::FunctionDecl *, const ast::StructInstType * );
147        void addTypeId( const ast::StructDecl * );
148        void addVtableForward( const ast::StructDecl * );
149        const ast::FunctionDecl * forwardDeclare( const ast::StructDecl * );
150        StructAndField addField( const ast::StructDecl * );
151        void addGetRoutines( const ast::ObjectDecl *, const ast::FunctionDecl * );
152        void addLockUnlockRoutines( const ast::StructDecl * );
153
154private:
155        const std::string type_name;
156        const std::string field_name;
157        const std::string getter_name;
158        const std::string context_error;
159        const std::string exception_name;
160        const std::string typeid_name;
161        const std::string vtable_name;
162        const bool needs_main;
163        const ast::AggregateDecl::Aggregate cast_target;
164
165        const ast::StructDecl   * type_decl = nullptr;
166        const ast::FunctionDecl * dtor_decl = nullptr;
167        const ast::StructDecl * except_decl = nullptr;
168        const ast::StructDecl * typeid_decl = nullptr;
169        const ast::StructDecl * vtable_decl = nullptr;
170
171};
172
173// Handles thread type declarations:
174//
175// thread Mythread {                         struct MyThread {
176//  int data;                                  int data;
177//  a_struct_t more_data;                      a_struct_t more_data;
178//                                =>             thread$ __thrd_d;
179// };                                        };
180//                                           static inline thread$ * get_thread( MyThread * this ) { return &this->__thrd_d; }
181//
182struct ThreadKeyword final : public ConcurrentSueKeyword {
183        ThreadKeyword() : ConcurrentSueKeyword(
184                "thread$",
185                "__thrd",
186                "get_thread",
187                "thread keyword requires threads to be in scope, add #include <thread.hfa>\n",
188                "ThreadCancelled",
189                true,
190                ast::AggregateDecl::Thread )
191        {}
192
193        virtual ~ThreadKeyword() {}
194};
195
196// Handles coroutine type declarations:
197//
198// coroutine MyCoroutine {                   struct MyCoroutine {
199//  int data;                                  int data;
200//  a_struct_t more_data;                      a_struct_t more_data;
201//                                =>             coroutine$ __cor_d;
202// };                                        };
203//                                           static inline coroutine$ * get_coroutine( MyCoroutine * this ) { return &this->__cor_d; }
204//
205struct CoroutineKeyword final : public ConcurrentSueKeyword {
206        CoroutineKeyword() : ConcurrentSueKeyword(
207                "coroutine$",
208                "__cor",
209                "get_coroutine",
210                "coroutine keyword requires coroutines to be in scope, add #include <coroutine.hfa>\n",
211                "CoroutineCancelled",
212                true,
213                ast::AggregateDecl::Coroutine )
214        {}
215
216        virtual ~CoroutineKeyword() {}
217};
218
219// Handles monitor type declarations:
220//
221// monitor MyMonitor {                       struct MyMonitor {
222//  int data;                                  int data;
223//  a_struct_t more_data;                      a_struct_t more_data;
224//                                =>             monitor$ __mon_d;
225// };                                        };
226//                                           static inline monitor$ * get_coroutine( MyMonitor * this ) {
227//                                               return &this->__cor_d;
228//                                           }
229//                                           void lock(MyMonitor & this) {
230//                                               lock(get_monitor(this));
231//                                           }
232//                                           void unlock(MyMonitor & this) {
233//                                               unlock(get_monitor(this));
234//                                           }
235//
236struct MonitorKeyword final : public ConcurrentSueKeyword {
237        MonitorKeyword() : ConcurrentSueKeyword(
238                "monitor$",
239                "__mon",
240                "get_monitor",
241                "monitor keyword requires monitors to be in scope, add #include <monitor.hfa>\n",
242                "",
243                false,
244                ast::AggregateDecl::Monitor )
245        {}
246
247        virtual ~MonitorKeyword() {}
248};
249
250// Handles generator type declarations:
251//
252// generator MyGenerator {                   struct MyGenerator {
253//  int data;                                  int data;
254//  a_struct_t more_data;                      a_struct_t more_data;
255//                                =>             int __generator_state;
256// };                                        };
257//
258struct GeneratorKeyword final : public ConcurrentSueKeyword {
259        GeneratorKeyword() : ConcurrentSueKeyword(
260                "generator$",
261                "__generator_state",
262                "get_generator",
263                "Unable to find builtin type generator$\n",
264                "",
265                true,
266                ast::AggregateDecl::Generator )
267        {}
268
269        virtual ~GeneratorKeyword() {}
270};
271
272const ast::Decl * ConcurrentSueKeyword::postvisit(
273                const ast::StructDecl * decl ) {
274        if ( !decl->body ) {
275                return decl;
276        } else if ( cast_target == decl->kind ) {
277                return handleStruct( decl );
278        } else if ( type_name == decl->name ) {
279                assert( !type_decl );
280                type_decl = decl;
281        } else if ( exception_name == decl->name ) {
282                assert( !except_decl );
283                except_decl = decl;
284        } else if ( typeid_name == decl->name ) {
285                assert( !typeid_decl );
286                typeid_decl = decl;
287        } else if ( vtable_name == decl->name ) {
288                assert( !vtable_decl );
289                vtable_decl = decl;
290        }
291        return decl;
292}
293
294// Try to get the full definition, but raise an error on conflicts.
295const ast::FunctionDecl * getDefinition(
296                const ast::FunctionDecl * old_decl,
297                const ast::FunctionDecl * new_decl ) {
298        if ( !new_decl->stmts ) {
299                return old_decl;
300        } else if ( !old_decl->stmts ) {
301                return new_decl;
302        } else {
303                assert( !old_decl->stmts || !new_decl->stmts );
304                return nullptr;
305        }
306}
307
308const ast::DeclWithType * ConcurrentSueKeyword::postvisit(
309                const ast::FunctionDecl * decl ) {
310        if ( type_decl && isDestructorFor( decl, type_decl ) ) {
311                // Check for forward declarations, try to get the full definition.
312                dtor_decl = (dtor_decl) ? getDefinition( dtor_decl, decl ) : decl;
313        } else if ( !vtable_name.empty() && decl->has_body() ) {
314                if (const ast::DeclWithType * param = isMainFor( decl, cast_target )) {
315                        if ( !vtable_decl ) {
316                                SemanticError( decl, context_error );
317                        }
318                        // Should be safe because of isMainFor.
319                        const ast::StructInstType * struct_type =
320                                static_cast<const ast::StructInstType *>(
321                                        static_cast<const ast::ReferenceType *>(
322                                                param->get_type() )->base.get() );
323
324                        handleMain( decl, struct_type );
325                }
326        }
327        return decl;
328}
329
330const ast::Expr * ConcurrentSueKeyword::postvisit(
331                const ast::KeywordCastExpr * expr ) {
332        if ( cast_target == expr->target ) {
333                // Convert `(thread &)ex` to `(thread$ &)*get_thread(ex)`, etc.
334                if ( !type_decl || !dtor_decl ) {
335                        SemanticError( expr, context_error );
336                }
337                assert( nullptr == expr->result );
338                auto cast = ast::mutate( expr );
339                cast->result = new ast::ReferenceType( new ast::StructInstType( type_decl ) );
340                cast->concrete_target.field  = field_name;
341                cast->concrete_target.getter = getter_name;
342                return cast;
343        }
344        return expr;
345}
346
347const ast::StructDecl * ConcurrentSueKeyword::handleStruct(
348                const ast::StructDecl * decl ) {
349        assert( decl->body );
350
351        if ( !type_decl || !dtor_decl ) {
352                SemanticError( decl, context_error );
353        }
354
355        if ( !exception_name.empty() ) {
356                if( !typeid_decl || !vtable_decl ) {
357                        SemanticError( decl, context_error );
358                }
359                addTypeId( decl );
360                addVtableForward( decl );
361        }
362
363        const ast::FunctionDecl * func = forwardDeclare( decl );
364        StructAndField addFieldRet = addField( decl );
365        decl = addFieldRet.decl;
366        const ast::ObjectDecl * field = addFieldRet.field;
367
368        addGetRoutines( field, func );
369        // Add routines to monitors for use by mutex stmt.
370        if ( ast::AggregateDecl::Monitor == cast_target ) {
371                addLockUnlockRoutines( decl );
372        }
373
374        return decl;
375}
376
377void ConcurrentSueKeyword::handleMain(
378                const ast::FunctionDecl * decl, const ast::StructInstType * type ) {
379        assert( vtable_decl );
380        assert( except_decl );
381
382        const CodeLocation & location = decl->location;
383
384        std::vector<ast::ptr<ast::Expr>> poly_args = {
385                new ast::TypeExpr( location, type ),
386        };
387        ast::ObjectDecl * vtable_object = Virtual::makeVtableInstance(
388                location,
389                "_default_vtable_object_declaration",
390                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
391                type,
392                nullptr
393        );
394        declsToAddAfter.push_back( vtable_object );
395        declsToAddAfter.push_back(
396                new ast::ObjectDecl(
397                        location,
398                        Virtual::concurrentDefaultVTableName(),
399                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
400                        new ast::SingleInit( location,
401                                new ast::VariableExpr( location, vtable_object ) )
402                )
403        );
404        declsToAddAfter.push_back( Virtual::makeGetExceptionFunction(
405                location,
406                vtable_object,
407                new ast::StructInstType( except_decl, copy( poly_args ) )
408        ) );
409}
410
411void ConcurrentSueKeyword::addTypeId( const ast::StructDecl * decl ) {
412        assert( typeid_decl );
413        const CodeLocation & location = decl->location;
414
415        ast::StructInstType * typeid_type =
416                new ast::StructInstType( typeid_decl, ast::CV::Const );
417        typeid_type->params.push_back(
418                new ast::TypeExpr( location, new ast::StructInstType( decl ) ) );
419        declsToAddBefore.push_back(
420                Virtual::makeTypeIdInstance( location, typeid_type ) );
421        // If the typeid_type is going to be kept, the other reference will have
422        // been made by now, but we also get to avoid extra mutates.
423        ast::ptr<ast::StructInstType> typeid_cleanup = typeid_type;
424}
425
426void ConcurrentSueKeyword::addVtableForward( const ast::StructDecl * decl ) {
427        assert( vtable_decl );
428        const CodeLocation& location = decl->location;
429
430        std::vector<ast::ptr<ast::Expr>> poly_args = {
431                new ast::TypeExpr( location, new ast::StructInstType( decl ) ),
432        };
433        declsToAddBefore.push_back( Virtual::makeGetExceptionForward(
434                location,
435                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
436                new ast::StructInstType( except_decl, copy( poly_args ) )
437        ) );
438        ast::ObjectDecl * vtable_object = Virtual::makeVtableForward(
439                location,
440                "_default_vtable_object_declaration",
441                new ast::StructInstType( vtable_decl, std::move( poly_args ) )
442        );
443        declsToAddBefore.push_back( vtable_object );
444        declsToAddBefore.push_back(
445                new ast::ObjectDecl(
446                        location,
447                        Virtual::concurrentDefaultVTableName(),
448                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
449                        nullptr,
450                        ast::Storage::Extern,
451                        ast::Linkage::Cforall
452                )
453        );
454}
455
456const ast::FunctionDecl * ConcurrentSueKeyword::forwardDeclare(
457                const ast::StructDecl * decl ) {
458        const CodeLocation & location = decl->location;
459
460        ast::StructDecl * forward = ast::deepCopy( decl );
461        {
462                // If removing members makes ref-count go to zero, do not free.
463                ast::ptr<ast::StructDecl> forward_ptr = forward;
464                forward->body = false;
465                forward->members.clear();
466                forward_ptr.release();
467        }
468
469        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
470                location,
471                "this",
472                new ast::ReferenceType( new ast::StructInstType( decl ) )
473        );
474
475        ast::ObjectDecl * ret_decl = new ast::ObjectDecl(
476                location,
477                "ret",
478                new ast::PointerType( new ast::StructInstType( type_decl ) )
479        );
480
481        ast::FunctionDecl * get_decl = new ast::FunctionDecl(
482                location,
483                getter_name,
484                {}, // forall
485                { this_decl }, // params
486                { ret_decl }, // returns
487                nullptr, // stmts
488                ast::Storage::Static,
489                ast::Linkage::Cforall,
490                { new ast::Attribute( "const" ) },
491                ast::Function::Inline
492        );
493        get_decl = fixupGenerics( get_decl, decl );
494
495        ast::FunctionDecl * main_decl = nullptr;
496        if ( needs_main ) {
497                // `this_decl` is copied here because the original was used above.
498                main_decl = new ast::FunctionDecl(
499                        location,
500                        "main",
501                        {},
502                        { ast::deepCopy( this_decl ) },
503                        {},
504                        nullptr,
505                        ast::Storage::Classes(),
506                        ast::Linkage::Cforall
507                );
508                main_decl = fixupGenerics( main_decl, decl );
509        }
510
511        declsToAddBefore.push_back( forward );
512        if ( needs_main ) declsToAddBefore.push_back( main_decl );
513        declsToAddBefore.push_back( get_decl );
514
515        return get_decl;
516}
517
518ConcurrentSueKeyword::StructAndField ConcurrentSueKeyword::addField(
519                const ast::StructDecl * decl ) {
520        const CodeLocation & location = decl->location;
521
522        ast::ObjectDecl * field = new ast::ObjectDecl(
523                location,
524                field_name,
525                new ast::StructInstType( type_decl )
526        );
527
528        auto mutDecl = ast::mutate( decl );
529        mutDecl->members.push_back( field );
530
531        return {mutDecl, field};
532}
533
534void ConcurrentSueKeyword::addGetRoutines(
535                const ast::ObjectDecl * field, const ast::FunctionDecl * forward ) {
536        // Clone the signature and then build the body.
537        ast::FunctionDecl * decl = ast::deepCopy( forward );
538
539        // Say it is generated at the "same" places as the forward declaration.
540        const CodeLocation & location = decl->location;
541
542        const ast::DeclWithType * param = decl->params.front();
543        ast::Stmt * stmt = new ast::ReturnStmt( location,
544                new ast::AddressExpr( location,
545                        new ast::MemberExpr( location,
546                                field,
547                                new ast::CastExpr( location,
548                                        new ast::VariableExpr( location, param ),
549                                        ast::deepCopy( param->get_type()->stripReferences() ),
550                                        ast::ExplicitCast
551                                )
552                        )
553                )
554        );
555
556        decl->stmts = new ast::CompoundStmt( location, { stmt } );
557        declsToAddAfter.push_back( decl );
558}
559
560void ConcurrentSueKeyword::addLockUnlockRoutines(
561                const ast::StructDecl * decl ) {
562        // This should only be used on monitors.
563        assert( ast::AggregateDecl::Monitor == cast_target );
564
565        const CodeLocation & location = decl->location;
566
567        // The parameter for both routines.
568        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
569                location,
570                "this",
571                new ast::ReferenceType( new ast::StructInstType( decl ) )
572        );
573
574        ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
575                location,
576                "lock",
577                { /* forall */ },
578                {
579                        // Copy the declaration of this.
580                        ast::deepCopy( this_decl ),
581                },
582                { /* returns */ },
583                nullptr,
584                ast::Storage::Static,
585                ast::Linkage::Cforall,
586                { /* attributes */ },
587                ast::Function::Inline
588        );
589        lock_decl = fixupGenerics( lock_decl, decl );
590
591        lock_decl->stmts = new ast::CompoundStmt( location, {
592                new ast::ExprStmt( location,
593                        new ast::UntypedExpr( location,
594                                new ast::NameExpr( location, "lock" ),
595                                {
596                                        new ast::UntypedExpr( location,
597                                                new ast::NameExpr( location, "get_monitor" ),
598                                                { new ast::VariableExpr( location,
599                                                        InitTweak::getParamThis( lock_decl ) ) }
600                                        )
601                                }
602                        )
603                )
604        } );
605
606        ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
607                location,
608                "unlock",
609                { /* forall */ },
610                {
611                        // Last use, consume the declaration of this.
612                        this_decl,
613                },
614                { /* returns */ },
615                nullptr,
616                ast::Storage::Static,
617                ast::Linkage::Cforall,
618                { /* attributes */ },
619                ast::Function::Inline
620        );
621        unlock_decl = fixupGenerics( unlock_decl, decl );
622
623        unlock_decl->stmts = new ast::CompoundStmt( location, {
624                new ast::ExprStmt( location,
625                        new ast::UntypedExpr( location,
626                                new ast::NameExpr( location, "unlock" ),
627                                {
628                                        new ast::UntypedExpr( location,
629                                                new ast::NameExpr( location, "get_monitor" ),
630                                                { new ast::VariableExpr( location,
631                                                        InitTweak::getParamThis( unlock_decl ) ) }
632                                        )
633                                }
634                        )
635                )
636        } );
637
638        declsToAddAfter.push_back( lock_decl );
639        declsToAddAfter.push_back( unlock_decl );
640}
641
642
643// --------------------------------------------------------------------------
644struct SuspendKeyword final :
645                public ast::WithStmtsToAdd<>, public ast::WithGuards {
646        SuspendKeyword() = default;
647        virtual ~SuspendKeyword() = default;
648
649        void previsit( const ast::FunctionDecl * );
650        const ast::DeclWithType * postvisit( const ast::FunctionDecl * );
651        const ast::Stmt * postvisit( const ast::SuspendStmt * );
652
653private:
654        bool is_real_suspend( const ast::FunctionDecl * );
655
656        const ast::Stmt * make_generator_suspend( const ast::SuspendStmt * );
657        const ast::Stmt * make_coroutine_suspend( const ast::SuspendStmt * );
658
659        struct LabelPair {
660                ast::Label obj;
661                int idx;
662        };
663
664        LabelPair make_label(const ast::Stmt * stmt ) {
665                labels.push_back( ControlStruct::newLabel( "generator", stmt ) );
666                return { labels.back(), int(labels.size()) };
667        }
668
669        const ast::DeclWithType * in_generator = nullptr;
670        const ast::FunctionDecl * decl_suspend = nullptr;
671        std::vector<ast::Label> labels;
672};
673
674void SuspendKeyword::previsit( const ast::FunctionDecl * decl ) {
675        GuardValue( in_generator ); in_generator = nullptr;
676
677        // If it is the real suspend, grab it if we don't have one already.
678        if ( is_real_suspend( decl ) ) {
679                decl_suspend = decl_suspend ? decl_suspend : decl;
680                return;
681        }
682
683        // Otherwise check if this is a generator main and, if so, handle it.
684        auto param = isMainFor( decl, ast::AggregateDecl::Generator );
685        if ( !param ) return;
686
687        if ( 0 != decl->returns.size() ) {
688                SemanticError( decl->location, "Generator main must return void" );
689        }
690
691        in_generator = param;
692        GuardValue( labels ); labels.clear();
693}
694
695const ast::DeclWithType * SuspendKeyword::postvisit(
696                const ast::FunctionDecl * decl ) {
697        // Only modify a full definition of a generator with states.
698        if ( !decl->stmts || !in_generator || labels.empty() ) return decl;
699
700        const CodeLocation & location = decl->location;
701
702        // Create a new function body:
703        // static void * __generator_labels[] = {&&s0, &&s1, ...};
704        // void * __generator_label = __generator_labels[GEN.__generator_state];
705        // goto * __generator_label;
706        // s0: ;
707        // OLD_BODY
708
709        // This is the null statement inserted right before the body.
710        ast::NullStmt * noop = new ast::NullStmt( location );
711        noop->labels.push_back( ControlStruct::newLabel( "generator", noop ) );
712        const ast::Label & first_label = noop->labels.back();
713
714        // Add each label to the init, starting with the first label.
715        std::vector<ast::ptr<ast::Init>> inits = {
716                new ast::SingleInit( location,
717                        new ast::LabelAddressExpr( location, copy( first_label ) ) ) };
718        // Then go through all the stored labels, and clear the store.
719        for ( auto && label : labels ) {
720                inits.push_back( new ast::SingleInit( label.location,
721                        new ast::LabelAddressExpr( label.location, std::move( label )
722                        ) ) );
723        }
724        labels.clear();
725        // Then construct the initializer itself.
726        auto init = new ast::ListInit( location, std::move( inits ) );
727
728        ast::ObjectDecl * generatorLabels = new ast::ObjectDecl(
729                location,
730                "__generator_labels",
731                new ast::ArrayType(
732                        new ast::PointerType( new ast::VoidType() ),
733                        nullptr,
734                        ast::FixedLen,
735                        ast::DynamicDim
736                ),
737                init,
738                ast::Storage::Classes(),
739                ast::Linkage::AutoGen
740        );
741
742        ast::ObjectDecl * generatorLabel = new ast::ObjectDecl(
743                location,
744                "__generator_label",
745                new ast::PointerType( new ast::VoidType() ),
746                new ast::SingleInit( location,
747                        new ast::UntypedExpr( location,
748                                new ast::NameExpr( location, "?[?]" ),
749                                {
750                                        // TODO: Could be a variable expr.
751                                        new ast::NameExpr( location, "__generator_labels" ),
752                                        new ast::UntypedMemberExpr( location,
753                                                new ast::NameExpr( location, "__generator_state" ),
754                                                new ast::VariableExpr( location, in_generator )
755                                        )
756                                }
757                        )
758                ),
759                ast::Storage::Classes(),
760                ast::Linkage::AutoGen
761        );
762
763        ast::BranchStmt * theGoTo = new ast::BranchStmt(
764                location, new ast::VariableExpr( location, generatorLabel )
765        );
766
767        // The noop goes here in order.
768
769        ast::CompoundStmt * body = new ast::CompoundStmt( location, {
770                { new ast::DeclStmt( location, generatorLabels ) },
771                { new ast::DeclStmt( location, generatorLabel ) },
772                { theGoTo },
773                { noop },
774                { decl->stmts },
775        } );
776
777        auto mutDecl = ast::mutate( decl );
778        mutDecl->stmts = body;
779        return mutDecl;
780}
781
782const ast::Stmt * SuspendKeyword::postvisit( const ast::SuspendStmt * stmt ) {
783        switch ( stmt->kind ) {
784        case ast::SuspendStmt::None:
785                // Use the context to determain the implicit target.
786                if ( in_generator ) {
787                        return make_generator_suspend( stmt );
788                } else {
789                        return make_coroutine_suspend( stmt );
790                }
791        case ast::SuspendStmt::Coroutine:
792                return make_coroutine_suspend( stmt );
793        case ast::SuspendStmt::Generator:
794                // Generator suspends must be directly in a generator.
795                if ( !in_generator ) SemanticError( stmt->location, "'suspend generator' must be used inside main of generator type." );
796                return make_generator_suspend( stmt );
797        }
798        assert( false );
799        return stmt;
800}
801
802/// Find the real/official suspend declaration.
803bool SuspendKeyword::is_real_suspend( const ast::FunctionDecl * decl ) {
804        return ( !decl->linkage.is_mangled
805                && 0 == decl->params.size()
806                && 0 == decl->returns.size()
807                && "__cfactx_suspend" == decl->name );
808}
809
810const ast::Stmt * SuspendKeyword::make_generator_suspend(
811                const ast::SuspendStmt * stmt ) {
812        assert( in_generator );
813        // Target code is:
814        //   GEN.__generator_state = X;
815        //   THEN
816        //   return;
817        //   __gen_X:;
818
819        const CodeLocation & location = stmt->location;
820
821        LabelPair label = make_label( stmt );
822
823        // This is the context saving statement.
824        stmtsToAddBefore.push_back( new ast::ExprStmt( location,
825                new ast::UntypedExpr( location,
826                        new ast::NameExpr( location, "?=?" ),
827                        {
828                                new ast::UntypedMemberExpr( location,
829                                        new ast::NameExpr( location, "__generator_state" ),
830                                        new ast::VariableExpr( location, in_generator )
831                                ),
832                                ast::ConstantExpr::from_int( location, label.idx ),
833                        }
834                )
835        ) );
836
837        // The THEN component is conditional (return is not).
838        if ( stmt->then ) {
839                stmtsToAddBefore.push_back( stmt->then.get() );
840        }
841        stmtsToAddBefore.push_back( new ast::ReturnStmt( location, nullptr ) );
842
843        // The null statement replaces the old suspend statement.
844        return new ast::NullStmt( location, { label.obj } );
845}
846
847const ast::Stmt * SuspendKeyword::make_coroutine_suspend(
848                const ast::SuspendStmt * stmt ) {
849        // The only thing we need from the old statement is the location.
850        const CodeLocation & location = stmt->location;
851
852        if ( !decl_suspend ) {
853                SemanticError( location, "suspend keyword applied to coroutines requires coroutines to be in scope, add #include <coroutine.hfa>\n" );
854        }
855        if ( stmt->then ) {
856                SemanticError( location, "Compound statement following coroutines is not implemented." );
857        }
858
859        return new ast::ExprStmt( location,
860                new ast::UntypedExpr( location,
861                        ast::VariableExpr::functionPointer( location, decl_suspend ) )
862        );
863}
864
865// --------------------------------------------------------------------------
866struct MutexKeyword final : public ast::WithDeclsToAdd<> {
867        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
868        void postvisit( const ast::StructDecl * decl );
869        const ast::Stmt * postvisit( const ast::MutexStmt * stmt );
870
871        static std::vector<const ast::DeclWithType *> findMutexArgs(
872                        const ast::FunctionDecl * decl, bool & first );
873        static void validate( const ast::DeclWithType * decl );
874
875        ast::CompoundStmt * addDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
876        ast::CompoundStmt * addStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
877        ast::CompoundStmt * addStatements( const ast::CompoundStmt * body, const std::vector<ast::ptr<ast::Expr>> & args );
878        ast::CompoundStmt * addThreadDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt * body, const std::vector<const ast::DeclWithType *> & args );
879        ast::ExprStmt * genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param);
880        ast::IfStmt * genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam );
881private:
882        const ast::StructDecl * monitor_decl = nullptr;
883        const ast::StructDecl * guard_decl = nullptr;
884        const ast::StructDecl * dtor_guard_decl = nullptr;
885        const ast::StructDecl * thread_guard_decl = nullptr;
886        const ast::StructDecl * lock_guard_decl = nullptr;
887
888        static ast::ptr<ast::Type> generic_func;
889
890        UniqueName mutex_func_namer = UniqueName("__lock_unlock_curr");
891};
892
893const ast::FunctionDecl * MutexKeyword::postvisit(
894                const ast::FunctionDecl * decl ) {
895        bool is_first_argument_mutex = false;
896        const std::vector<const ast::DeclWithType *> mutexArgs =
897                findMutexArgs( decl, is_first_argument_mutex );
898        bool const isDtor = CodeGen::isDestructor( decl->name );
899
900        // Does this function have any mutex arguments that connect to monitors?
901        if ( mutexArgs.empty() ) {
902                // If this is the destructor for a monitor it must be mutex.
903                if ( isDtor ) {
904                        // This reflects MutexKeyword::validate, but no error messages.
905                        const ast::Type * type = decl->type->params.front();
906
907                        // If it's a copy, it's not a mutex.
908                        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
909                        if ( nullptr == refType ) {
910                                return decl;
911                        }
912
913                        // If it is not pointing directly to a type, it's not a mutex.
914                        auto base = refType->base;
915                        if ( base.as<ast::ReferenceType>() ) return decl;
916                        if ( base.as<ast::PointerType>() ) return decl;
917
918                        // If it is not a struct, it's not a mutex.
919                        auto baseStruct = base.as<ast::StructInstType>();
920                        if ( nullptr == baseStruct ) return decl;
921
922                        // If it is a monitor, then it is a monitor.
923                        if( baseStruct->base->is_monitor() || baseStruct->base->is_thread() ) {
924                                SemanticError( decl, "destructors for structures declared as \"monitor\" must use mutex parameters\n" );
925                        }
926                }
927                return decl;
928        }
929
930        // Monitors can't be constructed with mutual exclusion.
931        if ( CodeGen::isConstructor( decl->name ) && is_first_argument_mutex ) {
932                SemanticError( decl, "constructors cannot have mutex parameters\n" );
933        }
934
935        // It makes no sense to have multiple mutex parameters for the destructor.
936        if ( isDtor && mutexArgs.size() != 1 ) {
937                SemanticError( decl, "destructors can only have 1 mutex argument\n" );
938        }
939
940        // Make sure all the mutex arguments are monitors.
941        for ( auto arg : mutexArgs ) {
942                validate( arg );
943        }
944
945        // Check to see if the body needs to be instrument the body.
946        const ast::CompoundStmt * body = decl->stmts;
947        if ( !body ) return decl;
948
949        // Check to if the required headers have been seen.
950        if ( !monitor_decl || !guard_decl || !dtor_guard_decl ) {
951                SemanticError( decl, "mutex keyword requires monitors to be in scope, add #include <monitor.hfa>\n" );
952        }
953
954        // Instrument the body.
955        ast::CompoundStmt * newBody = nullptr;
956        if ( isDtor && isThread( mutexArgs.front() ) ) {
957                if ( !thread_guard_decl ) {
958                        SemanticError( decl, "thread destructor requires threads to be in scope, add #include <thread.hfa>\n" );
959                }
960                newBody = addThreadDtorStatements( decl, body, mutexArgs );
961        } else if ( isDtor ) {
962                newBody = addDtorStatements( decl, body, mutexArgs );
963        } else {
964                newBody = addStatements( decl, body, mutexArgs );
965        }
966        assert( newBody );
967        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, newBody );
968}
969
970void MutexKeyword::postvisit( const ast::StructDecl * decl ) {
971        if ( !decl->body ) {
972                return;
973        } else if ( decl->name == "monitor$" ) {
974                assert( !monitor_decl );
975                monitor_decl = decl;
976        } else if ( decl->name == "monitor_guard_t" ) {
977                assert( !guard_decl );
978                guard_decl = decl;
979        } else if ( decl->name == "monitor_dtor_guard_t" ) {
980                assert( !dtor_guard_decl );
981                dtor_guard_decl = decl;
982        } else if ( decl->name == "thread_dtor_guard_t" ) {
983                assert( !thread_guard_decl );
984                thread_guard_decl = decl;
985        } else if ( decl->name == "__mutex_stmt_lock_guard" ) {
986                assert( !lock_guard_decl );
987                lock_guard_decl = decl;
988        }
989}
990
991const ast::Stmt * MutexKeyword::postvisit( const ast::MutexStmt * stmt ) {
992        if ( !lock_guard_decl ) {
993                SemanticError( stmt->location, "mutex stmt requires a header, add #include <mutex_stmt.hfa>\n" );
994        }
995        ast::CompoundStmt * body =
996                        new ast::CompoundStmt( stmt->location, { stmt->stmt } );
997       
998        return addStatements( body, stmt->mutexObjs );;
999}
1000
1001std::vector<const ast::DeclWithType *> MutexKeyword::findMutexArgs(
1002                const ast::FunctionDecl * decl, bool & first ) {
1003        std::vector<const ast::DeclWithType *> mutexArgs;
1004
1005        bool once = true;
1006        for ( auto arg : decl->params ) {
1007                const ast::Type * type = arg->get_type();
1008                if ( type->is_mutex() ) {
1009                        if ( once ) first = true;
1010                        mutexArgs.push_back( arg.get() );
1011                }
1012                once = false;
1013        }
1014        return mutexArgs;
1015}
1016
1017void MutexKeyword::validate( const ast::DeclWithType * decl ) {
1018        const ast::Type * type = decl->get_type();
1019
1020        // If it's a copy, it's not a mutex.
1021        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
1022        if ( nullptr == refType ) {
1023                SemanticError( decl, "Mutex argument must be of reference type " );
1024        }
1025
1026        // If it is not pointing directly to a type, it's not a mutex.
1027        auto base = refType->base;
1028        if ( base.as<ast::ReferenceType>() || base.as<ast::PointerType>() ) {
1029                SemanticError( decl, "Mutex argument have exactly one level of indirection " );
1030        }
1031
1032        // If it is not a struct, it's not a mutex.
1033        auto baseStruct = base.as<ast::StructInstType>();
1034        if ( nullptr == baseStruct ) return;
1035
1036        // Make sure that only the outer reference is mutex.
1037        if( baseStruct->is_mutex() ) {
1038                SemanticError( decl, "mutex keyword may only appear once per argument " );
1039        }
1040}
1041
1042ast::CompoundStmt * MutexKeyword::addDtorStatements(
1043                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1044                const std::vector<const ast::DeclWithType *> & args ) {
1045        ast::Type * argType = ast::shallowCopy( args.front()->get_type() );
1046        argType->set_mutex( false );
1047
1048        ast::CompoundStmt * mutBody = ast::mutate( body );
1049
1050        // Generated code goes near the beginning of body:
1051        const CodeLocation & location = mutBody->location;
1052
1053        const ast::ObjectDecl * monitor = new ast::ObjectDecl(
1054                location,
1055                "__monitor",
1056                new ast::PointerType( new ast::StructInstType( monitor_decl ) ),
1057                new ast::SingleInit(
1058                        location,
1059                        new ast::UntypedExpr(
1060                                location,
1061                                new ast::NameExpr( location, "get_monitor" ),
1062                                { new ast::CastExpr(
1063                                        location,
1064                                        new ast::VariableExpr( location, args.front() ),
1065                                        argType, ast::ExplicitCast
1066                                ) }
1067                        )
1068                )
1069        );
1070
1071        assert( generic_func );
1072
1073        // In reverse order:
1074        // monitor_dtor_guard_t __guard = { __monitor, func, false };
1075        mutBody->push_front(
1076                new ast::DeclStmt( location, new ast::ObjectDecl(
1077                        location,
1078                        "__guard",
1079                        new ast::StructInstType( dtor_guard_decl ),
1080                        new ast::ListInit(
1081                                location,
1082                                {
1083                                        new ast::SingleInit( location,
1084                                                new ast::AddressExpr( location,
1085                                                        new ast::VariableExpr( location, monitor ) ) ),
1086                                        new ast::SingleInit( location,
1087                                                new ast::CastExpr( location,
1088                                                        new ast::VariableExpr( location, func ),
1089                                                        generic_func,
1090                                                        ast::ExplicitCast ) ),
1091                                        new ast::SingleInit( location,
1092                                                ast::ConstantExpr::from_bool( location, false ) ),
1093                                },
1094                                {},
1095                                ast::MaybeConstruct
1096                        )
1097                ))
1098        );
1099
1100        // monitor$ * __monitor = get_monitor(a);
1101        mutBody->push_front( new ast::DeclStmt( location, monitor ) );
1102
1103        return mutBody;
1104}
1105
1106ast::CompoundStmt * MutexKeyword::addStatements(
1107                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1108                const std::vector<const ast::DeclWithType * > & args ) {
1109        ast::CompoundStmt * mutBody = ast::mutate( body );
1110
1111        // Code is generated near the beginning of the compound statement.
1112        const CodeLocation & location = mutBody->location;
1113
1114        // Make pointer to the monitors.
1115        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1116                location,
1117                "__monitors",
1118                new ast::ArrayType(
1119                        new ast::PointerType(
1120                                new ast::StructInstType( monitor_decl )
1121                        ),
1122                        ast::ConstantExpr::from_ulong( location, args.size() ),
1123                        ast::FixedLen,
1124                        ast::DynamicDim
1125                ),
1126                new ast::ListInit(
1127                        location,
1128                        map_range<std::vector<ast::ptr<ast::Init>>>(
1129                                args,
1130                                []( const ast::DeclWithType * decl ) {
1131                                        return new ast::SingleInit(
1132                                                decl->location,
1133                                                new ast::UntypedExpr(
1134                                                        decl->location,
1135                                                        new ast::NameExpr( decl->location, "get_monitor" ),
1136                                                        {
1137                                                                new ast::CastExpr(
1138                                                                        decl->location,
1139                                                                        new ast::VariableExpr( decl->location, decl ),
1140                                                                        decl->get_type(),
1141                                                                        ast::ExplicitCast
1142                                                                )
1143                                                        }
1144                                                )
1145                                        );
1146                                }
1147                        )
1148                )
1149        );
1150
1151        assert( generic_func );
1152
1153        // In Reverse Order:
1154        mutBody->push_front(
1155                new ast::DeclStmt( location, new ast::ObjectDecl(
1156                        location,
1157                        "__guard",
1158                        new ast::StructInstType( guard_decl ),
1159                        new ast::ListInit(
1160                                location,
1161                                {
1162                                        new ast::SingleInit( location,
1163                                                new ast::VariableExpr( location, monitors ) ),
1164                                        new ast::SingleInit( location,
1165                                                ast::ConstantExpr::from_ulong( location, args.size() ) ),
1166                                        new ast::SingleInit( location, new ast::CastExpr(
1167                                                location,
1168                                                new ast::VariableExpr( location, func ),
1169                                                generic_func,
1170                                                ast::ExplicitCast
1171                                        ) ),
1172                                },
1173                                {},
1174                                ast::MaybeConstruct
1175                        )
1176                ))
1177        );
1178
1179        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1180        mutBody->push_front( new ast::DeclStmt( location, monitors ) );
1181
1182        return mutBody;
1183}
1184
1185// generates a cast to the void ptr to the appropriate lock type and dereferences it before calling lock or unlock on it
1186// used to undo the type erasure done by storing all the lock pointers as void
1187ast::ExprStmt * MutexKeyword::genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param ) {
1188        return new ast::ExprStmt( location,
1189                new ast::UntypedExpr( location,
1190                        new ast::NameExpr( location, fnName ), {
1191                                ast::UntypedExpr::createDeref(
1192                                        location,
1193                                        new ast::CastExpr( location, 
1194                                                param,
1195                                                new ast::PointerType( new ast::TypeofType( new ast::UntypedExpr(
1196                                                        expr->location,
1197                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_type" ),
1198                                                        { expr }
1199                                                ) ) ),
1200                                                ast::GeneratedFlag::ExplicitCast
1201                                        )
1202                                )
1203                        }
1204                )
1205        );
1206}
1207
1208ast::IfStmt * MutexKeyword::genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam ) {
1209        ast::IfStmt * outerLockIf = nullptr;
1210        ast::IfStmt * lastLockIf = nullptr;
1211
1212        //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1213        for ( long unsigned int i = 0; i < args.size(); i++ ) {
1214               
1215                ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1216                        new ast::NameExpr( location, "?==?" ), {
1217                                ast::deepCopy( thisParam ),
1218                                new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1219                        }
1220                );
1221
1222                ast::IfStmt * currLockIf = new ast::IfStmt( 
1223                        location,
1224                        ifCond,
1225                        genVirtLockUnlockExpr( fnName, args.at(i), location, ast::deepCopy( thisParam ) )
1226                );
1227               
1228                if ( i == 0 ) {
1229                        outerLockIf = currLockIf;
1230                } else {
1231                        // add ifstmt to else of previous stmt
1232                        lastLockIf->else_ = currLockIf;
1233                }
1234
1235                lastLockIf = currLockIf;
1236        }
1237        return outerLockIf;
1238}
1239
1240void flattenTuple( const ast::UntypedTupleExpr * tuple, std::vector<ast::ptr<ast::Expr>> & output ) {
1241    for ( auto & expr : tuple->exprs ) {
1242        const ast::UntypedTupleExpr * innerTuple = dynamic_cast<const ast::UntypedTupleExpr *>(expr.get());
1243        if ( innerTuple ) flattenTuple( innerTuple, output );
1244        else output.emplace_back( ast::deepCopy( expr ));
1245    }
1246}
1247
1248ast::CompoundStmt * MutexKeyword::addStatements(
1249                const ast::CompoundStmt * body,
1250                const std::vector<ast::ptr<ast::Expr>> & args ) {
1251
1252        // Code is generated near the beginning of the compound statement.
1253        const CodeLocation & location = body->location;
1254
1255                // final body to return
1256        ast::CompoundStmt * newBody = new ast::CompoundStmt( location );
1257
1258        // std::string lockFnName = mutex_func_namer.newName();
1259        // std::string unlockFnName = mutex_func_namer.newName();
1260
1261    // If any arguments to the mutex stmt are tuples, flatten them
1262    std::vector<ast::ptr<ast::Expr>> flattenedArgs;
1263    for ( auto & arg : args ) {
1264        const ast::UntypedTupleExpr * tuple = dynamic_cast<const ast::UntypedTupleExpr *>(args.at(0).get());
1265        if ( tuple ) flattenTuple( tuple, flattenedArgs );
1266        else flattenedArgs.emplace_back( ast::deepCopy( arg ));
1267    }
1268
1269        // Make pointer to the monitors.
1270        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1271                location,
1272                "__monitors",
1273                new ast::ArrayType(
1274                        new ast::PointerType(
1275                                new ast::VoidType()
1276                        ),
1277                        ast::ConstantExpr::from_ulong( location, flattenedArgs.size() ),
1278                        ast::FixedLen,
1279                        ast::DynamicDim
1280                ),
1281                new ast::ListInit(
1282                        location,
1283                        map_range<std::vector<ast::ptr<ast::Init>>>(
1284                                flattenedArgs, [](const ast::Expr * expr) {
1285                                        return new ast::SingleInit(
1286                                                expr->location,
1287                                                new ast::UntypedExpr(
1288                                                        expr->location,
1289                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_ptr" ),
1290                                                        { expr }
1291                                                )
1292                                        );
1293                                }
1294                        )
1295                )
1296        );
1297
1298        ast::StructInstType * lock_guard_struct =
1299                        new ast::StructInstType( lock_guard_decl );
1300
1301        // use try stmts to lock and finally to unlock
1302        ast::TryStmt * outerTry = nullptr;
1303        ast::TryStmt * currentTry;
1304        ast::CompoundStmt * lastBody = nullptr;
1305
1306        // adds a nested try stmt for each lock we are locking
1307        for ( long unsigned int i = 0; i < flattenedArgs.size(); i++ ) {
1308                ast::UntypedExpr * innerAccess = new ast::UntypedExpr( 
1309                        location,
1310                        new ast::NameExpr( location,"?[?]" ), {
1311                                new ast::NameExpr( location, "__monitors" ),
1312                                ast::ConstantExpr::from_int( location, i )
1313                        }
1314                );
1315
1316                // make the try body
1317                ast::CompoundStmt * currTryBody = new ast::CompoundStmt( location );
1318                ast::IfStmt * lockCall = genTypeDiscrimLockUnlock( "lock", flattenedArgs, location, innerAccess );
1319                currTryBody->push_back( lockCall );
1320
1321                // make the finally stmt
1322                ast::CompoundStmt * currFinallyBody = new ast::CompoundStmt( location );
1323                ast::IfStmt * unlockCall = genTypeDiscrimLockUnlock( "unlock", flattenedArgs, location, innerAccess );
1324                currFinallyBody->push_back( unlockCall );
1325
1326                // construct the current try
1327                currentTry = new ast::TryStmt(
1328                        location,
1329                        currTryBody,
1330                        {},
1331                        new ast::FinallyClause( location, currFinallyBody )
1332                );
1333                if ( i == 0 ) outerTry = currentTry;
1334                else {
1335                        // pushback try into the body of the outer try
1336                        lastBody->push_back( currentTry );
1337                }
1338                lastBody = currTryBody;
1339        }
1340
1341        // push body into innermost try body
1342        if ( lastBody != nullptr ) {
1343                lastBody->push_back( body );
1344                newBody->push_front( outerTry );
1345        }
1346
1347        // monitor_guard_t __guard = { __monitors, # };
1348        newBody->push_front(
1349                new ast::DeclStmt(
1350                        location,
1351                        new ast::ObjectDecl(
1352                                location,
1353                                "__guard",
1354                                lock_guard_struct,
1355                                new ast::ListInit(
1356                                        location,
1357                                        {
1358                                                new ast::SingleInit(
1359                                                        location,
1360                                                        new ast::VariableExpr( location, monitors ) ),
1361                                                new ast::SingleInit(
1362                                                        location,
1363                                                        ast::ConstantExpr::from_ulong( location, flattenedArgs.size() ) ),
1364                                        },
1365                                        {},
1366                                        ast::MaybeConstruct
1367                                )
1368                        )
1369                )
1370        );
1371
1372        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1373        newBody->push_front( new ast::DeclStmt( location, monitors ) );
1374
1375        // // The parameter for both __lock_curr/__unlock_curr routines.
1376        // ast::ObjectDecl * this_decl = new ast::ObjectDecl(
1377        //      location,
1378        //      "this",
1379        //      new ast::PointerType( new ast::VoidType() ),
1380        //      nullptr,
1381        //      {},
1382        //      ast::Linkage::Cforall
1383        // );
1384
1385        // ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
1386        //      location,
1387        //      lockFnName,
1388        //      { /* forall */ },
1389        //      {
1390        //              // Copy the declaration of this.
1391        //              this_decl,
1392        //      },
1393        //      { /* returns */ },
1394        //      nullptr,
1395        //      0,
1396        //      ast::Linkage::Cforall,
1397        //      { /* attributes */ },
1398        //      ast::Function::Inline
1399        // );
1400
1401        // ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
1402        //      location,
1403        //      unlockFnName,
1404        //      { /* forall */ },
1405        //      {
1406        //              // Copy the declaration of this.
1407        //              ast::deepCopy( this_decl ),
1408        //      },
1409        //      { /* returns */ },
1410        //      nullptr,
1411        //      0,
1412        //      ast::Linkage::Cforall,
1413        //      { /* attributes */ },
1414        //      ast::Function::Inline
1415        // );
1416
1417        // ast::IfStmt * outerLockIf = nullptr;
1418        // ast::IfStmt * outerUnlockIf = nullptr;
1419        // ast::IfStmt * lastLockIf = nullptr;
1420        // ast::IfStmt * lastUnlockIf = nullptr;
1421
1422        // //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1423        // for ( long unsigned int i = 0; i < args.size(); i++ ) {
1424        //      ast::VariableExpr * thisParam = new ast::VariableExpr( location, InitTweak::getParamThis( lock_decl ) );
1425        //      ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1426        //              new ast::NameExpr( location, "?==?" ), {
1427        //                      thisParam,
1428        //                      new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1429        //              }
1430        //      );
1431
1432        //      ast::IfStmt * currLockIf = new ast::IfStmt(
1433        //              location,
1434        //              ast::deepCopy( ifCond ),
1435        //              genVirtLockUnlockExpr( "lock", args.at(i), location, ast::deepCopy( thisParam ) )
1436        //      );
1437
1438        //      ast::IfStmt * currUnlockIf = new ast::IfStmt(
1439        //              location,
1440        //              ifCond,
1441        //              genVirtLockUnlockExpr( "unlock", args.at(i), location, ast::deepCopy( thisParam ) )
1442        //      );
1443               
1444        //      if ( i == 0 ) {
1445        //              outerLockIf = currLockIf;
1446        //              outerUnlockIf = currUnlockIf;
1447        //      } else {
1448        //              // add ifstmt to else of previous stmt
1449        //              lastLockIf->else_ = currLockIf;
1450        //              lastUnlockIf->else_ = currUnlockIf;
1451        //      }
1452
1453        //      lastLockIf = currLockIf;
1454        //      lastUnlockIf = currUnlockIf;
1455        // }
1456       
1457        // // add pointer typing if/elifs to body of routines
1458        // lock_decl->stmts = new ast::CompoundStmt( location, { outerLockIf } );
1459        // unlock_decl->stmts = new ast::CompoundStmt( location, { outerUnlockIf } );
1460
1461        // // add routines to scope
1462        // declsToAddBefore.push_back( lock_decl );
1463        // declsToAddBefore.push_back( unlock_decl );
1464
1465        // newBody->push_front(new ast::DeclStmt( location, lock_decl ));
1466        // newBody->push_front(new ast::DeclStmt( location, unlock_decl ));
1467
1468        return newBody;
1469}
1470
1471ast::CompoundStmt * MutexKeyword::addThreadDtorStatements(
1472                const ast::FunctionDecl*, const ast::CompoundStmt * body,
1473                const std::vector<const ast::DeclWithType * > & args ) {
1474        assert( args.size() == 1 );
1475        const ast::DeclWithType * arg = args.front();
1476        const ast::Type * argType = arg->get_type();
1477        assert( argType->is_mutex() );
1478
1479        ast::CompoundStmt * mutBody = ast::mutate( body );
1480
1481        // The code is generated near the front of the body.
1482        const CodeLocation & location = mutBody->location;
1483
1484        // thread_dtor_guard_t __guard = { this, intptr( 0 ) };
1485        mutBody->push_front( new ast::DeclStmt(
1486                location,
1487                new ast::ObjectDecl(
1488                        location,
1489                        "__guard",
1490                        new ast::StructInstType( thread_guard_decl ),
1491                        new ast::ListInit(
1492                                location,
1493                                {
1494                                        new ast::SingleInit( location,
1495                                                new ast::CastExpr( location,
1496                                                        new ast::VariableExpr( location, arg ), argType ) ),
1497                                        new ast::SingleInit(
1498                                                location,
1499                                                new ast::UntypedExpr(
1500                                                        location,
1501                                                        new ast::NameExpr( location, "intptr" ), {
1502                                                                ast::ConstantExpr::from_int( location, 0 ),
1503                                                        }
1504                                                ) ),
1505                                },
1506                                {},
1507                                ast::MaybeConstruct
1508                        )
1509                )
1510        ));
1511
1512        return mutBody;
1513}
1514
1515ast::ptr<ast::Type> MutexKeyword::generic_func =
1516        new ast::FunctionType( ast::VariableArgs );
1517
1518// --------------------------------------------------------------------------
1519struct ThreadStarter final {
1520        void previsit( const ast::StructDecl * decl );
1521        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
1522
1523private:
1524        bool thread_ctor_seen = false;
1525        const ast::StructDecl * thread_decl = nullptr;
1526};
1527
1528void ThreadStarter::previsit( const ast::StructDecl * decl ) {
1529        if ( decl->body && decl->name == "thread$" ) {
1530                assert( !thread_decl );
1531                thread_decl = decl;
1532        }
1533}
1534
1535const ast::FunctionDecl * ThreadStarter::postvisit( const ast::FunctionDecl * decl ) {
1536        if ( !CodeGen::isConstructor( decl->name ) ) return decl;
1537
1538        // Seach for the thread constructor.
1539        // (Are the "prefixes" of these to blocks the same?)
1540        const ast::Type * typeof_this = InitTweak::getTypeofThis( decl->type );
1541        auto ctored_type = dynamic_cast<const ast::StructInstType *>( typeof_this );
1542        if ( ctored_type && ctored_type->base == thread_decl ) {
1543                thread_ctor_seen = true;
1544        }
1545
1546        // Modify this declaration, the extra checks to see if we will are first.
1547        const ast::ptr<ast::DeclWithType> & param = decl->params.front();
1548        auto type = dynamic_cast<const ast::StructInstType *>(
1549                ast::getPointerBase( param->get_type() ) );
1550        if ( nullptr == type ) return decl;
1551        if ( !type->base->is_thread() ) return decl;
1552        if ( !thread_decl || !thread_ctor_seen ) {
1553                SemanticError( type->base->location, "thread keyword requires threads to be in scope, add #include <thread.hfa>" );
1554        }
1555        const ast::CompoundStmt * stmt = decl->stmts;
1556        if ( nullptr == stmt ) return decl;
1557
1558        // Now do the actual modification:
1559        ast::CompoundStmt * mutStmt = ast::mutate( stmt );
1560        const CodeLocation & location = mutStmt->location;
1561        mutStmt->push_back(
1562                new ast::ExprStmt(
1563                        location,
1564                        new ast::UntypedExpr(
1565                                location,
1566                                new ast::NameExpr( location, "__thrd_start" ),
1567                                {
1568                                        new ast::VariableExpr( location, param ),
1569                                        new ast::NameExpr( location, "main" ),
1570                                }
1571                        )
1572                )
1573        );
1574
1575        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, mutStmt );
1576}
1577
1578} // namespace
1579
1580// --------------------------------------------------------------------------
1581// Interface Functions:
1582
1583void implementKeywords( ast::TranslationUnit & translationUnit ) {
1584        ast::Pass<ThreadKeyword>::run( translationUnit );
1585        ast::Pass<CoroutineKeyword>::run( translationUnit );
1586        ast::Pass<MonitorKeyword>::run( translationUnit );
1587        ast::Pass<GeneratorKeyword>::run( translationUnit );
1588        ast::Pass<SuspendKeyword>::run( translationUnit );
1589}
1590
1591void implementMutex( ast::TranslationUnit & translationUnit ) {
1592        ast::Pass<MutexKeyword>::run( translationUnit );
1593}
1594
1595void implementThreadStarter( ast::TranslationUnit & translationUnit ) {
1596        ast::Pass<ThreadStarter>::run( translationUnit );
1597}
1598
1599}
1600
1601// Local Variables: //
1602// tab-width: 4 //
1603// mode: c++ //
1604// compile-command: "make install" //
1605// End: //
Note: See TracBrowser for help on using the repository browser.