source: src/Concurrency/KeywordsNew.cpp @ 7ad47df

ADTast-experimentalpthread-emulationqualifiedEnum
Last change on this file since 7ad47df was b230091, checked in by Andrew Beach <ajbeach@…>, 3 years ago

Added a 'missing' TypeInstType? constructor and rewrote some calls to use it.

  • Property mode set to 100644
File size: 48.7 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// KeywordsNew.cpp -- Implement concurrency constructs from their keywords.
8//
9// Author           : Andrew Beach
10// Created On       : Tue Nov 16  9:53:00 2021
11// Last Modified By : Andrew Beach
12// Last Modified On : Fri Mar 11 10:40:00 2022
13// Update Count     : 2
14//
15
16#include <iostream>
17
18#include "Concurrency/Keywords.h"
19
20#include "AST/Copy.hpp"
21#include "AST/Decl.hpp"
22#include "AST/Expr.hpp"
23#include "AST/Pass.hpp"
24#include "AST/Stmt.hpp"
25#include "AST/DeclReplacer.hpp"
26#include "AST/TranslationUnit.hpp"
27#include "CodeGen/OperatorTable.h"
28#include "Common/Examine.h"
29#include "Common/utility.h"
30#include "Common/UniqueName.h"
31#include "ControlStruct/LabelGeneratorNew.hpp"
32#include "InitTweak/InitTweak.h"
33#include "Virtual/Tables.h"
34
35namespace Concurrency {
36
37namespace {
38
39// --------------------------------------------------------------------------
40// Loose Helper Functions:
41
42/// Detect threads constructed with the keyword thread.
43bool isThread( const ast::DeclWithType * decl ) {
44        auto baseType = decl->get_type()->stripDeclarator();
45        auto instType = dynamic_cast<const ast::StructInstType *>( baseType );
46        if ( nullptr == instType ) { return false; }
47        return instType->base->is_thread();
48}
49
50/// Get the virtual type id if given a type name.
51std::string typeIdType( std::string const & exception_name ) {
52        return exception_name.empty() ? std::string()
53                : Virtual::typeIdType( exception_name );
54}
55
56/// Get the vtable type name if given a type name.
57std::string vtableTypeName( std::string const & exception_name ) {
58        return exception_name.empty() ? std::string()
59                : Virtual::vtableTypeName( exception_name );
60}
61
62static ast::Type * mutate_under_references( ast::ptr<ast::Type>& type ) {
63        ast::Type * mutType = type.get_and_mutate();
64        for ( ast::ReferenceType * mutRef
65                ; (mutRef = dynamic_cast<ast::ReferenceType *>( mutType ))
66                ; mutType = mutRef->base.get_and_mutate() );
67        return mutType;
68}
69
70// Describe that it adds the generic parameters and the uses of the generic
71// parameters on the function and first "this" argument.
72ast::FunctionDecl * fixupGenerics(
73                const ast::FunctionDecl * func, const ast::StructDecl * decl ) {
74        const CodeLocation & location = decl->location;
75        // We have to update both the declaration
76        auto mutFunc = ast::mutate( func );
77        auto mutType = mutFunc->type.get_and_mutate();
78
79        if ( decl->params.empty() ) {
80                return mutFunc;
81        }
82
83        assert( 0 != mutFunc->params.size() );
84        assert( 0 != mutType->params.size() );
85
86        // Add the "forall" clause information.
87        for ( const ast::ptr<ast::TypeDecl> & typeParam : decl->params ) {
88                auto typeDecl = ast::deepCopy( typeParam );
89                mutFunc->type_params.push_back( typeDecl );
90                mutType->forall.push_back( new ast::TypeInstType( typeDecl ) );
91                for ( auto & assertion : typeDecl->assertions ) {
92                        mutFunc->assertions.push_back( assertion );
93                        mutType->assertions.emplace_back(
94                                new ast::VariableExpr( location, assertion ) );
95                }
96                typeDecl->assertions.clear();
97        }
98
99        // Even chain_mutate is not powerful enough for this:
100        ast::ptr<ast::Type>& paramType = strict_dynamic_cast<ast::ObjectDecl *>(
101                mutFunc->params[0].get_and_mutate() )->type;
102        auto paramTypeInst = strict_dynamic_cast<ast::StructInstType *>(
103                mutate_under_references( paramType ) );
104        auto typeParamInst = strict_dynamic_cast<ast::StructInstType *>(
105                mutate_under_references( mutType->params[0] ) );
106
107        for ( const ast::ptr<ast::TypeDecl> & typeDecl : mutFunc->type_params ) {
108                paramTypeInst->params.push_back(
109                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
110                typeParamInst->params.push_back(
111                        new ast::TypeExpr( location, new ast::TypeInstType( typeDecl ) ) );
112        }
113
114        return mutFunc;
115}
116
117// --------------------------------------------------------------------------
118struct ConcurrentSueKeyword : public ast::WithDeclsToAdd<> {
119        ConcurrentSueKeyword(
120                std::string&& type_name, std::string&& field_name,
121                std::string&& getter_name, std::string&& context_error,
122                std::string&& exception_name,
123                bool needs_main, ast::AggregateDecl::Aggregate cast_target
124        ) :
125                type_name( type_name ), field_name( field_name ),
126                getter_name( getter_name ), context_error( context_error ),
127                exception_name( exception_name ),
128                typeid_name( typeIdType( exception_name ) ),
129                vtable_name( vtableTypeName( exception_name ) ),
130                needs_main( needs_main ), cast_target( cast_target )
131        {}
132
133        virtual ~ConcurrentSueKeyword() {}
134
135        const ast::Decl * postvisit( const ast::StructDecl * decl );
136        const ast::DeclWithType * postvisit( const ast::FunctionDecl * decl );
137        const ast::Expr * postvisit( const ast::KeywordCastExpr * expr );
138
139        struct StructAndField {
140                const ast::StructDecl * decl;
141                const ast::ObjectDecl * field;
142        };
143
144        const ast::StructDecl * handleStruct( const ast::StructDecl * );
145        void handleMain( const ast::FunctionDecl *, const ast::StructInstType * );
146        void addTypeId( const ast::StructDecl * );
147        void addVtableForward( const ast::StructDecl * );
148        const ast::FunctionDecl * forwardDeclare( const ast::StructDecl * );
149        StructAndField addField( const ast::StructDecl * );
150        void addGetRoutines( const ast::ObjectDecl *, const ast::FunctionDecl * );
151        void addLockUnlockRoutines( const ast::StructDecl * );
152
153private:
154        const std::string type_name;
155        const std::string field_name;
156        const std::string getter_name;
157        const std::string context_error;
158        const std::string exception_name;
159        const std::string typeid_name;
160        const std::string vtable_name;
161        const bool needs_main;
162        const ast::AggregateDecl::Aggregate cast_target;
163
164        const ast::StructDecl   * type_decl = nullptr;
165        const ast::FunctionDecl * dtor_decl = nullptr;
166        const ast::StructDecl * except_decl = nullptr;
167        const ast::StructDecl * typeid_decl = nullptr;
168        const ast::StructDecl * vtable_decl = nullptr;
169
170};
171
172// Handles thread type declarations:
173//
174// thread Mythread {                         struct MyThread {
175//  int data;                                  int data;
176//  a_struct_t more_data;                      a_struct_t more_data;
177//                                =>             thread$ __thrd_d;
178// };                                        };
179//                                           static inline thread$ * get_thread( MyThread * this ) { return &this->__thrd_d; }
180//
181struct ThreadKeyword final : public ConcurrentSueKeyword {
182        ThreadKeyword() : ConcurrentSueKeyword(
183                "thread$",
184                "__thrd",
185                "get_thread",
186                "thread keyword requires threads to be in scope, add #include <thread.hfa>\n",
187                "ThreadCancelled",
188                true,
189                ast::AggregateDecl::Thread )
190        {}
191
192        virtual ~ThreadKeyword() {}
193};
194
195// Handles coroutine type declarations:
196//
197// coroutine MyCoroutine {                   struct MyCoroutine {
198//  int data;                                  int data;
199//  a_struct_t more_data;                      a_struct_t more_data;
200//                                =>             coroutine$ __cor_d;
201// };                                        };
202//                                           static inline coroutine$ * get_coroutine( MyCoroutine * this ) { return &this->__cor_d; }
203//
204struct CoroutineKeyword final : public ConcurrentSueKeyword {
205        CoroutineKeyword() : ConcurrentSueKeyword(
206                "coroutine$",
207                "__cor",
208                "get_coroutine",
209                "coroutine keyword requires coroutines to be in scope, add #include <coroutine.hfa>\n",
210                "CoroutineCancelled",
211                true,
212                ast::AggregateDecl::Coroutine )
213        {}
214
215        virtual ~CoroutineKeyword() {}
216};
217
218// Handles monitor type declarations:
219//
220// monitor MyMonitor {                       struct MyMonitor {
221//  int data;                                  int data;
222//  a_struct_t more_data;                      a_struct_t more_data;
223//                                =>             monitor$ __mon_d;
224// };                                        };
225//                                           static inline monitor$ * get_coroutine( MyMonitor * this ) {
226//                                               return &this->__cor_d;
227//                                           }
228//                                           void lock(MyMonitor & this) {
229//                                               lock(get_monitor(this));
230//                                           }
231//                                           void unlock(MyMonitor & this) {
232//                                               unlock(get_monitor(this));
233//                                           }
234//
235struct MonitorKeyword final : public ConcurrentSueKeyword {
236        MonitorKeyword() : ConcurrentSueKeyword(
237                "monitor$",
238                "__mon",
239                "get_monitor",
240                "monitor keyword requires monitors to be in scope, add #include <monitor.hfa>\n",
241                "",
242                false,
243                ast::AggregateDecl::Monitor )
244        {}
245
246        virtual ~MonitorKeyword() {}
247};
248
249// Handles generator type declarations:
250//
251// generator MyGenerator {                   struct MyGenerator {
252//  int data;                                  int data;
253//  a_struct_t more_data;                      a_struct_t more_data;
254//                                =>             int __generator_state;
255// };                                        };
256//
257struct GeneratorKeyword final : public ConcurrentSueKeyword {
258        GeneratorKeyword() : ConcurrentSueKeyword(
259                "generator$",
260                "__generator_state",
261                "get_generator",
262                "Unable to find builtin type generator$\n",
263                "",
264                true,
265                ast::AggregateDecl::Generator )
266        {}
267
268        virtual ~GeneratorKeyword() {}
269};
270
271const ast::Decl * ConcurrentSueKeyword::postvisit(
272                const ast::StructDecl * decl ) {
273        if ( !decl->body ) {
274                return decl;
275        } else if ( cast_target == decl->kind ) {
276                return handleStruct( decl );
277        } else if ( type_name == decl->name ) {
278                assert( !type_decl );
279                type_decl = decl;
280        } else if ( exception_name == decl->name ) {
281                assert( !except_decl );
282                except_decl = decl;
283        } else if ( typeid_name == decl->name ) {
284                assert( !typeid_decl );
285                typeid_decl = decl;
286        } else if ( vtable_name == decl->name ) {
287                assert( !vtable_decl );
288                vtable_decl = decl;
289        }
290        return decl;
291}
292
293// Try to get the full definition, but raise an error on conflicts.
294const ast::FunctionDecl * getDefinition(
295                const ast::FunctionDecl * old_decl,
296                const ast::FunctionDecl * new_decl ) {
297        if ( !new_decl->stmts ) {
298                return old_decl;
299        } else if ( !old_decl->stmts ) {
300                return new_decl;
301        } else {
302                assert( !old_decl->stmts || !new_decl->stmts );
303                return nullptr;
304        }
305}
306
307const ast::DeclWithType * ConcurrentSueKeyword::postvisit(
308                const ast::FunctionDecl * decl ) {
309        if ( type_decl && isDestructorFor( decl, type_decl ) ) {
310                // Check for forward declarations, try to get the full definition.
311                dtor_decl = (dtor_decl) ? getDefinition( dtor_decl, decl ) : decl;
312        } else if ( !vtable_name.empty() && decl->has_body() ) {
313                if (const ast::DeclWithType * param = isMainFor( decl, cast_target )) {
314                        if ( !vtable_decl ) {
315                                SemanticError( decl, context_error );
316                        }
317                        // Should be safe because of isMainFor.
318                        const ast::StructInstType * struct_type =
319                                static_cast<const ast::StructInstType *>(
320                                        static_cast<const ast::ReferenceType *>(
321                                                param->get_type() )->base.get() );
322
323                        handleMain( decl, struct_type );
324                }
325        }
326        return decl;
327}
328
329const ast::Expr * ConcurrentSueKeyword::postvisit(
330                const ast::KeywordCastExpr * expr ) {
331        if ( cast_target == expr->target ) {
332                // Convert `(thread &)ex` to `(thread$ &)*get_thread(ex)`, etc.
333                if ( !type_decl || !dtor_decl ) {
334                        SemanticError( expr, context_error );
335                }
336                assert( nullptr == expr->result );
337                auto cast = ast::mutate( expr );
338                cast->result = new ast::ReferenceType( new ast::StructInstType( type_decl ) );
339                cast->concrete_target.field  = field_name;
340                cast->concrete_target.getter = getter_name;
341                return cast;
342        }
343        return expr;
344}
345
346const ast::StructDecl * ConcurrentSueKeyword::handleStruct(
347                const ast::StructDecl * decl ) {
348        assert( decl->body );
349
350        if ( !type_decl || !dtor_decl ) {
351                SemanticError( decl, context_error );
352        }
353
354        if ( !exception_name.empty() ) {
355                if( !typeid_decl || !vtable_decl ) {
356                        SemanticError( decl, context_error );
357                }
358                addTypeId( decl );
359                addVtableForward( decl );
360        }
361
362        const ast::FunctionDecl * func = forwardDeclare( decl );
363        StructAndField addFieldRet = addField( decl );
364        decl = addFieldRet.decl;
365        const ast::ObjectDecl * field = addFieldRet.field;
366
367        addGetRoutines( field, func );
368        // Add routines to monitors for use by mutex stmt.
369        if ( ast::AggregateDecl::Monitor == cast_target ) {
370                addLockUnlockRoutines( decl );
371        }
372
373        return decl;
374}
375
376void ConcurrentSueKeyword::handleMain(
377                const ast::FunctionDecl * decl, const ast::StructInstType * type ) {
378        assert( vtable_decl );
379        assert( except_decl );
380
381        const CodeLocation & location = decl->location;
382
383        std::vector<ast::ptr<ast::Expr>> poly_args = {
384                new ast::TypeExpr( location, type ),
385        };
386        ast::ObjectDecl * vtable_object = Virtual::makeVtableInstance(
387                location,
388                "_default_vtable_object_declaration",
389                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
390                type,
391                nullptr
392        );
393        declsToAddAfter.push_back( vtable_object );
394        declsToAddAfter.push_back(
395                new ast::ObjectDecl(
396                        location,
397                        Virtual::concurrentDefaultVTableName(),
398                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
399                        new ast::SingleInit( location,
400                                new ast::VariableExpr( location, vtable_object ) ),
401                        ast::Storage::Classes(),
402                        ast::Linkage::Cforall
403                )
404        );
405        declsToAddAfter.push_back( Virtual::makeGetExceptionFunction(
406                location,
407                vtable_object,
408                new ast::StructInstType( except_decl, copy( poly_args ) )
409        ) );
410}
411
412void ConcurrentSueKeyword::addTypeId( const ast::StructDecl * decl ) {
413        assert( typeid_decl );
414        const CodeLocation & location = decl->location;
415
416        ast::StructInstType * typeid_type =
417                new ast::StructInstType( typeid_decl, ast::CV::Const );
418        typeid_type->params.push_back(
419                new ast::TypeExpr( location, new ast::StructInstType( decl ) ) );
420        declsToAddBefore.push_back(
421                Virtual::makeTypeIdInstance( location, typeid_type ) );
422        // If the typeid_type is going to be kept, the other reference will have
423        // been made by now, but we also get to avoid extra mutates.
424        ast::ptr<ast::StructInstType> typeid_cleanup = typeid_type;
425}
426
427void ConcurrentSueKeyword::addVtableForward( const ast::StructDecl * decl ) {
428        assert( vtable_decl );
429        const CodeLocation& location = decl->location;
430
431        std::vector<ast::ptr<ast::Expr>> poly_args = {
432                new ast::TypeExpr( location, new ast::StructInstType( decl ) ),
433        };
434        declsToAddBefore.push_back( Virtual::makeGetExceptionForward(
435                location,
436                new ast::StructInstType( vtable_decl, copy( poly_args ) ),
437                new ast::StructInstType( except_decl, copy( poly_args ) )
438        ) );
439        ast::ObjectDecl * vtable_object = Virtual::makeVtableForward(
440                location,
441                "_default_vtable_object_declaration",
442                new ast::StructInstType( vtable_decl, std::move( poly_args ) )
443        );
444        declsToAddBefore.push_back( vtable_object );
445        declsToAddBefore.push_back(
446                new ast::ObjectDecl(
447                        location,
448                        Virtual::concurrentDefaultVTableName(),
449                        new ast::ReferenceType( vtable_object->type, ast::CV::Const ),
450                        nullptr,
451                        ast::Storage::Extern,
452                        ast::Linkage::Cforall
453                )
454        );
455}
456
457const ast::FunctionDecl * ConcurrentSueKeyword::forwardDeclare(
458                const ast::StructDecl * decl ) {
459        const CodeLocation & location = decl->location;
460
461        ast::StructDecl * forward = ast::deepCopy( decl );
462        {
463                // If removing members makes ref-count go to zero, do not free.
464                ast::ptr<ast::StructDecl> forward_ptr = forward;
465                forward->body = false;
466                forward->members.clear();
467                forward_ptr.release();
468        }
469
470        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
471                location,
472                "this",
473                new ast::ReferenceType( new ast::StructInstType( decl ) ),
474                nullptr,
475                ast::Storage::Classes(),
476                ast::Linkage::Cforall
477        );
478
479        ast::ObjectDecl * ret_decl = new ast::ObjectDecl(
480                location,
481                "ret",
482                new ast::PointerType( new ast::StructInstType( type_decl ) ),
483                nullptr,
484                ast::Storage::Classes(),
485                ast::Linkage::Cforall
486        );
487
488        ast::FunctionDecl * get_decl = new ast::FunctionDecl(
489                location,
490                getter_name,
491                {}, // forall
492                { this_decl }, // params
493                { ret_decl }, // returns
494                nullptr, // stmts
495                ast::Storage::Static,
496                ast::Linkage::Cforall,
497                { new ast::Attribute( "const" ) },
498                ast::Function::Inline
499        );
500        get_decl = fixupGenerics( get_decl, decl );
501
502        ast::FunctionDecl * main_decl = nullptr;
503        if ( needs_main ) {
504                // `this_decl` is copied here because the original was used above.
505                main_decl = new ast::FunctionDecl(
506                        location,
507                        "main",
508                        {},
509                        { ast::deepCopy( this_decl ) },
510                        {},
511                        nullptr,
512                        ast::Storage::Classes(),
513                        ast::Linkage::Cforall
514                );
515                main_decl = fixupGenerics( main_decl, decl );
516        }
517
518        declsToAddBefore.push_back( forward );
519        if ( needs_main ) declsToAddBefore.push_back( main_decl );
520        declsToAddBefore.push_back( get_decl );
521
522        return get_decl;
523}
524
525ConcurrentSueKeyword::StructAndField ConcurrentSueKeyword::addField(
526                const ast::StructDecl * decl ) {
527        const CodeLocation & location = decl->location;
528
529        ast::ObjectDecl * field = new ast::ObjectDecl(
530                location,
531                field_name,
532                new ast::StructInstType( type_decl ),
533                nullptr,
534                ast::Storage::Classes(),
535                ast::Linkage::Cforall
536        );
537
538        auto mutDecl = ast::mutate( decl );
539        mutDecl->members.push_back( field );
540
541        return {mutDecl, field};
542}
543
544void ConcurrentSueKeyword::addGetRoutines(
545                const ast::ObjectDecl * field, const ast::FunctionDecl * forward ) {
546        // Say it is generated at the "same" places as the forward declaration.
547        const CodeLocation & location = forward->location;
548
549        const ast::DeclWithType * param = forward->params.front();
550        ast::Stmt * stmt = new ast::ReturnStmt( location,
551                new ast::AddressExpr( location,
552                        new ast::MemberExpr( location,
553                                field,
554                                new ast::CastExpr( location,
555                                        new ast::VariableExpr( location, param ),
556                                        ast::deepCopy( param->get_type()->stripReferences() ),
557                                        ast::ExplicitCast
558                                )
559                        )
560                )
561        );
562
563        ast::FunctionDecl * decl = ast::deepCopy( forward );
564        decl->stmts = new ast::CompoundStmt( location, { stmt } );
565        declsToAddAfter.push_back( decl );
566}
567
568void ConcurrentSueKeyword::addLockUnlockRoutines(
569                const ast::StructDecl * decl ) {
570        // This should only be used on monitors.
571        assert( ast::AggregateDecl::Monitor == cast_target );
572
573        const CodeLocation & location = decl->location;
574
575        // The parameter for both routines.
576        ast::ObjectDecl * this_decl = new ast::ObjectDecl(
577                location,
578                "this",
579                new ast::ReferenceType( new ast::StructInstType( decl ) ),
580                nullptr,
581                ast::Storage::Classes(),
582                ast::Linkage::Cforall
583        );
584
585        ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
586                location,
587                "lock",
588                { /* forall */ },
589                {
590                        // Copy the declaration of this.
591                        ast::deepCopy( this_decl ),
592                },
593                { /* returns */ },
594                nullptr,
595                ast::Storage::Static,
596                ast::Linkage::Cforall,
597                { /* attributes */ },
598                ast::Function::Inline
599        );
600        lock_decl = fixupGenerics( lock_decl, decl );
601
602        lock_decl->stmts = new ast::CompoundStmt( location, {
603                new ast::ExprStmt( location,
604                        new ast::UntypedExpr( location,
605                                new ast::NameExpr( location, "lock" ),
606                                {
607                                        new ast::UntypedExpr( location,
608                                                new ast::NameExpr( location, "get_monitor" ),
609                                                { new ast::VariableExpr( location,
610                                                        InitTweak::getParamThis( lock_decl ) ) }
611                                        )
612                                }
613                        )
614                )
615        } );
616
617        ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
618                location,
619                "unlock",
620                { /* forall */ },
621                {
622                        // Last use, consume the declaration of this.
623                        this_decl,
624                },
625                { /* returns */ },
626                nullptr,
627                ast::Storage::Static,
628                ast::Linkage::Cforall,
629                { /* attributes */ },
630                ast::Function::Inline
631        );
632        unlock_decl = fixupGenerics( unlock_decl, decl );
633
634        unlock_decl->stmts = new ast::CompoundStmt( location, {
635                new ast::ExprStmt( location,
636                        new ast::UntypedExpr( location,
637                                new ast::NameExpr( location, "unlock" ),
638                                {
639                                        new ast::UntypedExpr( location,
640                                                new ast::NameExpr( location, "get_monitor" ),
641                                                { new ast::VariableExpr( location,
642                                                        InitTweak::getParamThis( unlock_decl ) ) }
643                                        )
644                                }
645                        )
646                )
647        } );
648
649        declsToAddAfter.push_back( lock_decl );
650        declsToAddAfter.push_back( unlock_decl );
651}
652
653
654// --------------------------------------------------------------------------
655struct SuspendKeyword final :
656                public ast::WithStmtsToAdd<>, public ast::WithGuards {
657        SuspendKeyword() = default;
658        virtual ~SuspendKeyword() = default;
659
660        void previsit( const ast::FunctionDecl * );
661        const ast::DeclWithType * postvisit( const ast::FunctionDecl * );
662        const ast::Stmt * postvisit( const ast::SuspendStmt * );
663
664private:
665        bool is_real_suspend( const ast::FunctionDecl * );
666
667        const ast::Stmt * make_generator_suspend( const ast::SuspendStmt * );
668        const ast::Stmt * make_coroutine_suspend( const ast::SuspendStmt * );
669
670        struct LabelPair {
671                ast::Label obj;
672                int idx;
673        };
674
675        LabelPair make_label(const ast::Stmt * stmt ) {
676                labels.push_back( ControlStruct::newLabel( "generator", stmt ) );
677                return { labels.back(), int(labels.size()) };
678        }
679
680        const ast::DeclWithType * in_generator = nullptr;
681        const ast::FunctionDecl * decl_suspend = nullptr;
682        std::vector<ast::Label> labels;
683};
684
685void SuspendKeyword::previsit( const ast::FunctionDecl * decl ) {
686        GuardValue( in_generator ); in_generator = nullptr;
687
688        // If it is the real suspend, grab it if we don't have one already.
689        if ( is_real_suspend( decl ) ) {
690                decl_suspend = decl_suspend ? decl_suspend : decl;
691                return;
692        }
693
694        // Otherwise check if this is a generator main and, if so, handle it.
695        auto param = isMainFor( decl, ast::AggregateDecl::Generator );
696        if ( !param ) return;
697
698        if ( 0 != decl->returns.size() ) {
699                SemanticError( decl->location, "Generator main must return void" );
700        }
701
702        in_generator = param;
703        GuardValue( labels ); labels.clear();
704}
705
706const ast::DeclWithType * SuspendKeyword::postvisit(
707                const ast::FunctionDecl * decl ) {
708        // Only modify a full definition of a generator with states.
709        if ( !decl->stmts || !in_generator || labels.empty() ) return decl;
710
711        const CodeLocation & location = decl->location;
712
713        // Create a new function body:
714        // static void * __generator_labels[] = {&&s0, &&s1, ...};
715        // void * __generator_label = __generator_labels[GEN.__generator_state];
716        // goto * __generator_label;
717        // s0: ;
718        // OLD_BODY
719
720        // This is the null statement inserted right before the body.
721        ast::NullStmt * noop = new ast::NullStmt( location );
722        noop->labels.push_back( ControlStruct::newLabel( "generator", noop ) );
723        const ast::Label & first_label = noop->labels.back();
724
725        // Add each label to the init, starting with the first label.
726        std::vector<ast::ptr<ast::Init>> inits = {
727                new ast::SingleInit( location,
728                        new ast::LabelAddressExpr( location, copy( first_label ) ) ) };
729        // Then go through all the stored labels, and clear the store.
730        for ( auto && label : labels ) {
731                inits.push_back( new ast::SingleInit( label.location,
732                        new ast::LabelAddressExpr( label.location, std::move( label )
733                        ) ) );
734        }
735        labels.clear();
736        // Then construct the initializer itself.
737        auto init = new ast::ListInit( location, std::move( inits ) );
738
739        ast::ObjectDecl * generatorLabels = new ast::ObjectDecl(
740                location,
741                "__generator_labels",
742                new ast::ArrayType(
743                        new ast::PointerType( new ast::VoidType() ),
744                        nullptr,
745                        ast::FixedLen,
746                        ast::DynamicDim
747                ),
748                init,
749                ast::Storage::Classes(),
750                ast::Linkage::AutoGen
751        );
752
753        ast::ObjectDecl * generatorLabel = new ast::ObjectDecl(
754                location,
755                "__generator_label",
756                new ast::PointerType( new ast::VoidType() ),
757                new ast::SingleInit( location,
758                        new ast::UntypedExpr( location,
759                                new ast::NameExpr( location, "?[?]" ),
760                                {
761                                        // TODO: Could be a variable expr.
762                                        new ast::NameExpr( location, "__generator_labels" ),
763                                        new ast::UntypedMemberExpr( location,
764                                                new ast::NameExpr( location, "__generator_state" ),
765                                                new ast::VariableExpr( location, in_generator )
766                                        )
767                                }
768                        )
769                ),
770                ast::Storage::Classes(),
771                ast::Linkage::AutoGen
772        );
773
774        ast::BranchStmt * theGoTo = new ast::BranchStmt(
775                location, new ast::VariableExpr( location, generatorLabel )
776        );
777
778        // The noop goes here in order.
779
780        ast::CompoundStmt * body = new ast::CompoundStmt( location, {
781                { new ast::DeclStmt( location, generatorLabels ) },
782                { new ast::DeclStmt( location, generatorLabel ) },
783                { theGoTo },
784                { noop },
785                { decl->stmts },
786        } );
787
788        auto mutDecl = ast::mutate( decl );
789        mutDecl->stmts = body;
790        return mutDecl;
791}
792
793const ast::Stmt * SuspendKeyword::postvisit( const ast::SuspendStmt * stmt ) {
794        switch ( stmt->type ) {
795        case ast::SuspendStmt::None:
796                // Use the context to determain the implicit target.
797                if ( in_generator ) {
798                        return make_generator_suspend( stmt );
799                } else {
800                        return make_coroutine_suspend( stmt );
801                }
802        case ast::SuspendStmt::Coroutine:
803                return make_coroutine_suspend( stmt );
804        case ast::SuspendStmt::Generator:
805                // Generator suspends must be directly in a generator.
806                if ( !in_generator ) SemanticError( stmt->location, "'suspend generator' must be used inside main of generator type." );
807                return make_generator_suspend( stmt );
808        }
809        assert( false );
810        return stmt;
811}
812
813/// Find the real/official suspend declaration.
814bool SuspendKeyword::is_real_suspend( const ast::FunctionDecl * decl ) {
815        return ( !decl->linkage.is_mangled
816                && 0 == decl->params.size()
817                && 0 == decl->returns.size()
818                && "__cfactx_suspend" == decl->name );
819}
820
821const ast::Stmt * SuspendKeyword::make_generator_suspend(
822                const ast::SuspendStmt * stmt ) {
823        assert( in_generator );
824        // Target code is:
825        //   GEN.__generator_state = X;
826        //   THEN
827        //   return;
828        //   __gen_X:;
829
830        const CodeLocation & location = stmt->location;
831
832        LabelPair label = make_label( stmt );
833
834        // This is the context saving statement.
835        stmtsToAddBefore.push_back( new ast::ExprStmt( location,
836                new ast::UntypedExpr( location,
837                        new ast::NameExpr( location, "?=?" ),
838                        {
839                                new ast::UntypedMemberExpr( location,
840                                        new ast::NameExpr( location, "__generator_state" ),
841                                        new ast::VariableExpr( location, in_generator )
842                                ),
843                                ast::ConstantExpr::from_int( location, label.idx ),
844                        }
845                )
846        ) );
847
848        // The THEN component is conditional (return is not).
849        if ( stmt->then ) {
850                stmtsToAddBefore.push_back( stmt->then.get() );
851        }
852        stmtsToAddBefore.push_back( new ast::ReturnStmt( location, nullptr ) );
853
854        // The null statement replaces the old suspend statement.
855        return new ast::NullStmt( location, { label.obj } );
856}
857
858const ast::Stmt * SuspendKeyword::make_coroutine_suspend(
859                const ast::SuspendStmt * stmt ) {
860        // The only thing we need from the old statement is the location.
861        const CodeLocation & location = stmt->location;
862
863        if ( !decl_suspend ) {
864                SemanticError( location, "suspend keyword applied to coroutines requires coroutines to be in scope, add #include <coroutine.hfa>\n" );
865        }
866        if ( stmt->then ) {
867                SemanticError( location, "Compound statement following coroutines is not implemented." );
868        }
869
870        return new ast::ExprStmt( location,
871                new ast::UntypedExpr( location,
872                        ast::VariableExpr::functionPointer( location, decl_suspend ) )
873        );
874}
875
876// --------------------------------------------------------------------------
877struct MutexKeyword final : public ast::WithDeclsToAdd<> {
878        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
879        void postvisit( const ast::StructDecl * decl );
880        const ast::Stmt * postvisit( const ast::MutexStmt * stmt );
881
882        static std::vector<const ast::DeclWithType *> findMutexArgs(
883                        const ast::FunctionDecl * decl, bool & first );
884        static void validate( const ast::DeclWithType * decl );
885
886        ast::CompoundStmt * addDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
887        ast::CompoundStmt * addStatements( const ast::FunctionDecl* func, const ast::CompoundStmt *, const std::vector<const ast::DeclWithType *> &);
888        ast::CompoundStmt * addStatements( const ast::CompoundStmt * body, const std::vector<ast::ptr<ast::Expr>> & args );
889        ast::CompoundStmt * addThreadDtorStatements( const ast::FunctionDecl* func, const ast::CompoundStmt * body, const std::vector<const ast::DeclWithType *> & args );
890        ast::ExprStmt * genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param);
891        ast::IfStmt * genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam );
892private:
893        const ast::StructDecl * monitor_decl = nullptr;
894        const ast::StructDecl * guard_decl = nullptr;
895        const ast::StructDecl * dtor_guard_decl = nullptr;
896        const ast::StructDecl * thread_guard_decl = nullptr;
897        const ast::StructDecl * lock_guard_decl = nullptr;
898
899        static ast::ptr<ast::Type> generic_func;
900
901        UniqueName mutex_func_namer = UniqueName("__lock_unlock_curr");
902};
903
904const ast::FunctionDecl * MutexKeyword::postvisit(
905                const ast::FunctionDecl * decl ) {
906        bool is_first_argument_mutex = false;
907        const std::vector<const ast::DeclWithType *> mutexArgs =
908                findMutexArgs( decl, is_first_argument_mutex );
909        bool const isDtor = CodeGen::isDestructor( decl->name );
910
911        // Does this function have any mutex arguments that connect to monitors?
912        if ( mutexArgs.empty() ) {
913                // If this is the destructor for a monitor it must be mutex.
914                if ( isDtor ) {
915                        // This reflects MutexKeyword::validate, but no error messages.
916                        const ast::Type * type = decl->type->params.front();
917
918                        // If it's a copy, it's not a mutex.
919                        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
920                        if ( nullptr == refType ) {
921                                return decl;
922                        }
923
924                        // If it is not pointing directly to a type, it's not a mutex.
925                        auto base = refType->base;
926                        if ( base.as<ast::ReferenceType>() ) return decl;
927                        if ( base.as<ast::PointerType>() ) return decl;
928
929                        // If it is not a struct, it's not a mutex.
930                        auto baseStruct = base.as<ast::StructInstType>();
931                        if ( nullptr == baseStruct ) return decl;
932
933                        // If it is a monitor, then it is a monitor.
934                        if( baseStruct->base->is_monitor() || baseStruct->base->is_thread() ) {
935                                SemanticError( decl, "destructors for structures declared as \"monitor\" must use mutex parameters\n" );
936                        }
937                }
938                return decl;
939        }
940
941        // Monitors can't be constructed with mutual exclusion.
942        if ( CodeGen::isConstructor( decl->name ) && is_first_argument_mutex ) {
943                SemanticError( decl, "constructors cannot have mutex parameters\n" );
944        }
945
946        // It makes no sense to have multiple mutex parameters for the destructor.
947        if ( isDtor && mutexArgs.size() != 1 ) {
948                SemanticError( decl, "destructors can only have 1 mutex argument\n" );
949        }
950
951        // Make sure all the mutex arguments are monitors.
952        for ( auto arg : mutexArgs ) {
953                validate( arg );
954        }
955
956        // Check to see if the body needs to be instrument the body.
957        const ast::CompoundStmt * body = decl->stmts;
958        if ( !body ) return decl;
959
960        // Check to if the required headers have been seen.
961        if ( !monitor_decl || !guard_decl || !dtor_guard_decl ) {
962                SemanticError( decl, "mutex keyword requires monitors to be in scope, add #include <monitor.hfa>\n" );
963        }
964
965        // Instrument the body.
966        ast::CompoundStmt * newBody = nullptr;
967        if ( isDtor && isThread( mutexArgs.front() ) ) {
968                if ( !thread_guard_decl ) {
969                        SemanticError( decl, "thread destructor requires threads to be in scope, add #include <thread.hfa>\n" );
970                }
971                newBody = addThreadDtorStatements( decl, body, mutexArgs );
972        } else if ( isDtor ) {
973                newBody = addDtorStatements( decl, body, mutexArgs );
974        } else {
975                newBody = addStatements( decl, body, mutexArgs );
976        }
977        assert( newBody );
978        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, newBody );
979}
980
981void MutexKeyword::postvisit( const ast::StructDecl * decl ) {
982        if ( !decl->body ) {
983                return;
984        } else if ( decl->name == "monitor$" ) {
985                assert( !monitor_decl );
986                monitor_decl = decl;
987        } else if ( decl->name == "monitor_guard_t" ) {
988                assert( !guard_decl );
989                guard_decl = decl;
990        } else if ( decl->name == "monitor_dtor_guard_t" ) {
991                assert( !dtor_guard_decl );
992                dtor_guard_decl = decl;
993        } else if ( decl->name == "thread_dtor_guard_t" ) {
994                assert( !thread_guard_decl );
995                thread_guard_decl = decl;
996        } else if ( decl->name == "__mutex_stmt_lock_guard" ) {
997                assert( !lock_guard_decl );
998                lock_guard_decl = decl;
999        }
1000}
1001
1002const ast::Stmt * MutexKeyword::postvisit( const ast::MutexStmt * stmt ) {
1003        if ( !lock_guard_decl ) {
1004                SemanticError( stmt->location, "mutex stmt requires a header, add #include <mutex_stmt.hfa>\n" );
1005        }
1006        ast::CompoundStmt * body =
1007                        new ast::CompoundStmt( stmt->location, { stmt->stmt } );
1008       
1009        return addStatements( body, stmt->mutexObjs );;
1010}
1011
1012std::vector<const ast::DeclWithType *> MutexKeyword::findMutexArgs(
1013                const ast::FunctionDecl * decl, bool & first ) {
1014        std::vector<const ast::DeclWithType *> mutexArgs;
1015
1016        bool once = true;
1017        for ( auto arg : decl->params ) {
1018                const ast::Type * type = arg->get_type();
1019                if ( type->is_mutex() ) {
1020                        if ( once ) first = true;
1021                        mutexArgs.push_back( arg.get() );
1022                }
1023                once = false;
1024        }
1025        return mutexArgs;
1026}
1027
1028void MutexKeyword::validate( const ast::DeclWithType * decl ) {
1029        const ast::Type * type = decl->get_type();
1030
1031        // If it's a copy, it's not a mutex.
1032        const ast::ReferenceType * refType = dynamic_cast<const ast::ReferenceType *>( type );
1033        if ( nullptr == refType ) {
1034                SemanticError( decl, "Mutex argument must be of reference type " );
1035        }
1036
1037        // If it is not pointing directly to a type, it's not a mutex.
1038        auto base = refType->base;
1039        if ( base.as<ast::ReferenceType>() || base.as<ast::PointerType>() ) {
1040                SemanticError( decl, "Mutex argument have exactly one level of indirection " );
1041        }
1042
1043        // If it is not a struct, it's not a mutex.
1044        auto baseStruct = base.as<ast::StructInstType>();
1045        if ( nullptr == baseStruct ) return;
1046
1047        // Make sure that only the outer reference is mutex.
1048        if( baseStruct->is_mutex() ) {
1049                SemanticError( decl, "mutex keyword may only appear once per argument " );
1050        }
1051}
1052
1053ast::CompoundStmt * MutexKeyword::addDtorStatements(
1054                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1055                const std::vector<const ast::DeclWithType *> & args ) {
1056        ast::Type * argType = ast::shallowCopy( args.front()->get_type() );
1057        argType->set_mutex( false );
1058
1059        ast::CompoundStmt * mutBody = ast::mutate( body );
1060
1061        // Generated code goes near the beginning of body:
1062        const CodeLocation & location = mutBody->location;
1063
1064        const ast::ObjectDecl * monitor = new ast::ObjectDecl(
1065                location,
1066                "__monitor",
1067                new ast::PointerType( new ast::StructInstType( monitor_decl ) ),
1068                new ast::SingleInit(
1069                        location,
1070                        new ast::UntypedExpr(
1071                                location,
1072                                new ast::NameExpr( location, "get_monitor" ),
1073                                { new ast::CastExpr(
1074                                        location,
1075                                        new ast::VariableExpr( location, args.front() ),
1076                                        argType, ast::ExplicitCast
1077                                ) }
1078                        )
1079                ),
1080                ast::Storage::Classes(),
1081                ast::Linkage::Cforall
1082        );
1083
1084        assert( generic_func );
1085
1086        // In reverse order:
1087        // monitor_dtor_guard_t __guard = { __monitor, func, false };
1088        mutBody->push_front(
1089                new ast::DeclStmt( location, new ast::ObjectDecl(
1090                        location,
1091                        "__guard",
1092                        new ast::StructInstType( dtor_guard_decl ),
1093                        new ast::ListInit(
1094                                location,
1095                                {
1096                                        new ast::SingleInit( location,
1097                                                new ast::AddressExpr( location,
1098                                                        new ast::VariableExpr( location, monitor ) ) ),
1099                                        new ast::SingleInit( location,
1100                                                new ast::CastExpr( location,
1101                                                        new ast::VariableExpr( location, func ),
1102                                                        generic_func,
1103                                                        ast::ExplicitCast ) ),
1104                                        new ast::SingleInit( location,
1105                                                ast::ConstantExpr::from_bool( location, false ) ),
1106                                },
1107                                {},
1108                                ast::MaybeConstruct
1109                        ),
1110                        ast::Storage::Classes(),
1111                        ast::Linkage::Cforall
1112                ))
1113        );
1114
1115        // monitor$ * __monitor = get_monitor(a);
1116        mutBody->push_front( new ast::DeclStmt( location, monitor ) );
1117
1118        return mutBody;
1119}
1120
1121ast::CompoundStmt * MutexKeyword::addStatements(
1122                const ast::FunctionDecl* func, const ast::CompoundStmt * body,
1123                const std::vector<const ast::DeclWithType * > & args ) {
1124        ast::CompoundStmt * mutBody = ast::mutate( body );
1125
1126        // Code is generated near the beginning of the compound statement.
1127        const CodeLocation & location = mutBody->location;
1128
1129        // Make pointer to the monitors.
1130        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1131                location,
1132                "__monitors",
1133                new ast::ArrayType(
1134                        new ast::PointerType(
1135                                new ast::StructInstType( monitor_decl )
1136                        ),
1137                        ast::ConstantExpr::from_ulong( location, args.size() ),
1138                        ast::FixedLen,
1139                        ast::DynamicDim
1140                ),
1141                new ast::ListInit(
1142                        location,
1143                        map_range<std::vector<ast::ptr<ast::Init>>>(
1144                                args,
1145                                []( const ast::DeclWithType * decl ) {
1146                                        return new ast::SingleInit(
1147                                                decl->location,
1148                                                new ast::UntypedExpr(
1149                                                        decl->location,
1150                                                        new ast::NameExpr( decl->location, "get_monitor" ),
1151                                                        {
1152                                                                new ast::CastExpr(
1153                                                                        decl->location,
1154                                                                        new ast::VariableExpr( decl->location, decl ),
1155                                                                        decl->get_type(),
1156                                                                        ast::ExplicitCast
1157                                                                )
1158                                                        }
1159                                                )
1160                                        );
1161                                }
1162                        )
1163                ),
1164                ast::Storage::Classes(),
1165                ast::Linkage::Cforall
1166        );
1167
1168        assert( generic_func );
1169
1170        // In Reverse Order:
1171        mutBody->push_front(
1172                new ast::DeclStmt( location, new ast::ObjectDecl(
1173                        location,
1174                        "__guard",
1175                        new ast::StructInstType( guard_decl ),
1176                        new ast::ListInit(
1177                                location,
1178                                {
1179                                        new ast::SingleInit( location,
1180                                                new ast::VariableExpr( location, monitors ) ),
1181                                        new ast::SingleInit( location,
1182                                                ast::ConstantExpr::from_ulong( location, args.size() ) ),
1183                                        new ast::SingleInit( location, new ast::CastExpr(
1184                                                location,
1185                                                new ast::VariableExpr( location, func ),
1186                                                generic_func,
1187                                                ast::ExplicitCast
1188                                        ) ),
1189                                },
1190                                {},
1191                                ast::MaybeConstruct
1192                        ),
1193                        ast::Storage::Classes(),
1194                        ast::Linkage::Cforall
1195                ))
1196        );
1197
1198        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1199        mutBody->push_front( new ast::DeclStmt( location, monitors ) );
1200
1201        return mutBody;
1202}
1203
1204// generates a cast to the void ptr to the appropriate lock type and dereferences it before calling lock or unlock on it
1205// used to undo the type erasure done by storing all the lock pointers as void
1206ast::ExprStmt * MutexKeyword::genVirtLockUnlockExpr( const std::string & fnName, ast::ptr<ast::Expr> expr, const CodeLocation & location, ast::Expr * param ) {
1207        return new ast::ExprStmt( location,
1208                new ast::UntypedExpr( location,
1209                        new ast::NameExpr( location, fnName ), {
1210                                ast::UntypedExpr::createDeref(
1211                                        location,
1212                                        new ast::CastExpr( location, 
1213                                                param,
1214                                                new ast::PointerType( new ast::TypeofType( new ast::UntypedExpr(
1215                                                        expr->location,
1216                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_type" ),
1217                                                        { expr }
1218                                                ) ) ),
1219                                                ast::GeneratedFlag::ExplicitCast
1220                                        )
1221                                )
1222                        }
1223                )
1224        );
1225}
1226
1227ast::IfStmt * MutexKeyword::genTypeDiscrimLockUnlock( const std::string & fnName, const std::vector<ast::ptr<ast::Expr>> & args, const CodeLocation & location, ast::UntypedExpr * thisParam ) {
1228        ast::IfStmt * outerLockIf = nullptr;
1229        ast::IfStmt * lastLockIf = nullptr;
1230
1231        //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1232        for ( long unsigned int i = 0; i < args.size(); i++ ) {
1233               
1234                ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1235                        new ast::NameExpr( location, "?==?" ), {
1236                                ast::deepCopy( thisParam ),
1237                                new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1238                        }
1239                );
1240
1241                ast::IfStmt * currLockIf = new ast::IfStmt( 
1242                        location,
1243                        ifCond,
1244                        genVirtLockUnlockExpr( fnName, args.at(i), location, ast::deepCopy( thisParam ) )
1245                );
1246               
1247                if ( i == 0 ) {
1248                        outerLockIf = currLockIf;
1249                } else {
1250                        // add ifstmt to else of previous stmt
1251                        lastLockIf->else_ = currLockIf;
1252                }
1253
1254                lastLockIf = currLockIf;
1255        }
1256        return outerLockIf;
1257}
1258
1259ast::CompoundStmt * MutexKeyword::addStatements(
1260                const ast::CompoundStmt * body,
1261                const std::vector<ast::ptr<ast::Expr>> & args ) {
1262
1263        // Code is generated near the beginning of the compound statement.
1264        const CodeLocation & location = body->location;
1265
1266                // final body to return
1267        ast::CompoundStmt * newBody = new ast::CompoundStmt( location );
1268
1269        // std::string lockFnName = mutex_func_namer.newName();
1270        // std::string unlockFnName = mutex_func_namer.newName();
1271
1272        // Make pointer to the monitors.
1273        ast::ObjectDecl * monitors = new ast::ObjectDecl(
1274                location,
1275                "__monitors",
1276                new ast::ArrayType(
1277                        new ast::PointerType(
1278                                new ast::VoidType()
1279                        ),
1280                        ast::ConstantExpr::from_ulong( location, args.size() ),
1281                        ast::FixedLen,
1282                        ast::DynamicDim
1283                ),
1284                new ast::ListInit(
1285                        location,
1286                        map_range<std::vector<ast::ptr<ast::Init>>>(
1287                                args, [](const ast::Expr * expr) {
1288                                        return new ast::SingleInit(
1289                                                expr->location,
1290                                                new ast::UntypedExpr(
1291                                                        expr->location,
1292                                                        new ast::NameExpr( expr->location, "__get_mutexstmt_lock_ptr" ),
1293                                                        { expr }
1294                                                )
1295                                        );
1296                                }
1297                        )
1298                ),
1299                ast::Storage::Classes(),
1300                ast::Linkage::Cforall
1301        );
1302
1303        ast::StructInstType * lock_guard_struct =
1304                        new ast::StructInstType( lock_guard_decl );
1305
1306        // use try stmts to lock and finally to unlock
1307        ast::TryStmt * outerTry = nullptr;
1308        ast::TryStmt * currentTry;
1309        ast::CompoundStmt * lastBody = nullptr;
1310
1311        // adds a nested try stmt for each lock we are locking
1312        for ( long unsigned int i = 0; i < args.size(); i++ ) {
1313                ast::UntypedExpr * innerAccess = new ast::UntypedExpr( 
1314                        location,
1315                        new ast::NameExpr( location,"?[?]" ), {
1316                                new ast::NameExpr( location, "__monitors" ),
1317                                ast::ConstantExpr::from_int( location, i )
1318                        }
1319                );
1320
1321                // make the try body
1322                ast::CompoundStmt * currTryBody = new ast::CompoundStmt( location );
1323                ast::IfStmt * lockCall = genTypeDiscrimLockUnlock( "lock", args, location, innerAccess );
1324                currTryBody->push_back( lockCall );
1325
1326                // make the finally stmt
1327                ast::CompoundStmt * currFinallyBody = new ast::CompoundStmt( location );
1328                ast::IfStmt * unlockCall = genTypeDiscrimLockUnlock( "unlock", args, location, innerAccess );
1329                currFinallyBody->push_back( unlockCall );
1330
1331                // construct the current try
1332                currentTry = new ast::TryStmt(
1333                        location,
1334                        currTryBody,
1335                        {},
1336                        new ast::FinallyClause( location, currFinallyBody )
1337                );
1338                if ( i == 0 ) outerTry = currentTry;
1339                else {
1340                        // pushback try into the body of the outer try
1341                        lastBody->push_back( currentTry );
1342                }
1343                lastBody = currTryBody;
1344        }
1345
1346        // push body into innermost try body
1347        if ( lastBody != nullptr ) {
1348                lastBody->push_back( body );
1349                newBody->push_front( outerTry );
1350        }
1351
1352        // monitor_guard_t __guard = { __monitors, # };
1353        newBody->push_front(
1354                new ast::DeclStmt(
1355                        location,
1356                        new ast::ObjectDecl(
1357                                location,
1358                                "__guard",
1359                                lock_guard_struct,
1360                                new ast::ListInit(
1361                                        location,
1362                                        {
1363                                                new ast::SingleInit(
1364                                                        location,
1365                                                        new ast::VariableExpr( location, monitors ) ),
1366                                                new ast::SingleInit(
1367                                                        location,
1368                                                        ast::ConstantExpr::from_ulong( location, args.size() ) ),
1369                                        },
1370                                        {},
1371                                        ast::MaybeConstruct
1372                                ),
1373                                ast::Storage::Classes(),
1374                                ast::Linkage::Cforall
1375                        )
1376                )
1377        );
1378
1379        // monitor$ * __monitors[] = { get_monitor(a), get_monitor(b) };
1380        newBody->push_front( new ast::DeclStmt( location, monitors ) );
1381
1382        // // The parameter for both __lock_curr/__unlock_curr routines.
1383        // ast::ObjectDecl * this_decl = new ast::ObjectDecl(
1384        //      location,
1385        //      "this",
1386        //      new ast::PointerType( new ast::VoidType() ),
1387        //      nullptr,
1388        //      {},
1389        //      ast::Linkage::Cforall
1390        // );
1391
1392        // ast::FunctionDecl * lock_decl = new ast::FunctionDecl(
1393        //      location,
1394        //      lockFnName,
1395        //      { /* forall */ },
1396        //      {
1397        //              // Copy the declaration of this.
1398        //              this_decl,
1399        //      },
1400        //      { /* returns */ },
1401        //      nullptr,
1402        //      0,
1403        //      ast::Linkage::Cforall,
1404        //      { /* attributes */ },
1405        //      ast::Function::Inline
1406        // );
1407
1408        // ast::FunctionDecl * unlock_decl = new ast::FunctionDecl(
1409        //      location,
1410        //      unlockFnName,
1411        //      { /* forall */ },
1412        //      {
1413        //              // Copy the declaration of this.
1414        //              ast::deepCopy( this_decl ),
1415        //      },
1416        //      { /* returns */ },
1417        //      nullptr,
1418        //      0,
1419        //      ast::Linkage::Cforall,
1420        //      { /* attributes */ },
1421        //      ast::Function::Inline
1422        // );
1423
1424        // ast::IfStmt * outerLockIf = nullptr;
1425        // ast::IfStmt * outerUnlockIf = nullptr;
1426        // ast::IfStmt * lastLockIf = nullptr;
1427        // ast::IfStmt * lastUnlockIf = nullptr;
1428
1429        // //adds an if/elif clause for each lock to assign type from void ptr based on ptr address
1430        // for ( long unsigned int i = 0; i < args.size(); i++ ) {
1431        //      ast::VariableExpr * thisParam = new ast::VariableExpr( location, InitTweak::getParamThis( lock_decl ) );
1432        //      ast::UntypedExpr * ifCond = new ast::UntypedExpr( location,
1433        //              new ast::NameExpr( location, "?==?" ), {
1434        //                      thisParam,
1435        //                      new ast::CastExpr( location, new ast::AddressExpr( location, args.at(i) ), new ast::PointerType( new ast::VoidType() ))
1436        //              }
1437        //      );
1438
1439        //      ast::IfStmt * currLockIf = new ast::IfStmt(
1440        //              location,
1441        //              ast::deepCopy( ifCond ),
1442        //              genVirtLockUnlockExpr( "lock", args.at(i), location, ast::deepCopy( thisParam ) )
1443        //      );
1444
1445        //      ast::IfStmt * currUnlockIf = new ast::IfStmt(
1446        //              location,
1447        //              ifCond,
1448        //              genVirtLockUnlockExpr( "unlock", args.at(i), location, ast::deepCopy( thisParam ) )
1449        //      );
1450               
1451        //      if ( i == 0 ) {
1452        //              outerLockIf = currLockIf;
1453        //              outerUnlockIf = currUnlockIf;
1454        //      } else {
1455        //              // add ifstmt to else of previous stmt
1456        //              lastLockIf->else_ = currLockIf;
1457        //              lastUnlockIf->else_ = currUnlockIf;
1458        //      }
1459
1460        //      lastLockIf = currLockIf;
1461        //      lastUnlockIf = currUnlockIf;
1462        // }
1463       
1464        // // add pointer typing if/elifs to body of routines
1465        // lock_decl->stmts = new ast::CompoundStmt( location, { outerLockIf } );
1466        // unlock_decl->stmts = new ast::CompoundStmt( location, { outerUnlockIf } );
1467
1468        // // add routines to scope
1469        // declsToAddBefore.push_back( lock_decl );
1470        // declsToAddBefore.push_back( unlock_decl );
1471
1472        // newBody->push_front(new ast::DeclStmt( location, lock_decl ));
1473        // newBody->push_front(new ast::DeclStmt( location, unlock_decl ));
1474
1475        return newBody;
1476}
1477
1478ast::CompoundStmt * MutexKeyword::addThreadDtorStatements(
1479                const ast::FunctionDecl*, const ast::CompoundStmt * body,
1480                const std::vector<const ast::DeclWithType * > & args ) {
1481        assert( args.size() == 1 );
1482        const ast::DeclWithType * arg = args.front();
1483        const ast::Type * argType = arg->get_type();
1484        assert( argType->is_mutex() );
1485
1486        ast::CompoundStmt * mutBody = ast::mutate( body );
1487
1488        // The code is generated near the front of the body.
1489        const CodeLocation & location = mutBody->location;
1490
1491        // thread_dtor_guard_t __guard = { this, intptr( 0 ) };
1492        mutBody->push_front( new ast::DeclStmt(
1493                location,
1494                new ast::ObjectDecl(
1495                        location,
1496                        "__guard",
1497                        new ast::StructInstType( thread_guard_decl ),
1498                        new ast::ListInit(
1499                                location,
1500                                {
1501                                        new ast::SingleInit( location,
1502                                                new ast::CastExpr( location,
1503                                                        new ast::VariableExpr( location, arg ), argType ) ),
1504                                        new ast::SingleInit(
1505                                                location,
1506                                                new ast::UntypedExpr(
1507                                                        location,
1508                                                        new ast::NameExpr( location, "intptr" ), {
1509                                                                ast::ConstantExpr::from_int( location, 0 ),
1510                                                        }
1511                                                ) ),
1512                                },
1513                                {},
1514                                ast::MaybeConstruct
1515                        ),
1516                        ast::Storage::Classes(),
1517                        ast::Linkage::Cforall
1518                )
1519        ));
1520
1521        return mutBody;
1522}
1523
1524ast::ptr<ast::Type> MutexKeyword::generic_func =
1525        new ast::FunctionType( ast::VariableArgs );
1526
1527// --------------------------------------------------------------------------
1528struct ThreadStarter final {
1529        void previsit( const ast::StructDecl * decl );
1530        const ast::FunctionDecl * postvisit( const ast::FunctionDecl * decl );
1531
1532private:
1533        bool thread_ctor_seen = false;
1534        const ast::StructDecl * thread_decl = nullptr;
1535};
1536
1537void ThreadStarter::previsit( const ast::StructDecl * decl ) {
1538        if ( decl->body && decl->name == "thread$" ) {
1539                assert( !thread_decl );
1540                thread_decl = decl;
1541        }
1542}
1543
1544const ast::FunctionDecl * ThreadStarter::postvisit( const ast::FunctionDecl * decl ) {
1545        if ( !CodeGen::isConstructor( decl->name ) ) return decl;
1546
1547        // Seach for the thread constructor.
1548        // (Are the "prefixes" of these to blocks the same?)
1549        const ast::Type * typeof_this = InitTweak::getTypeofThis( decl->type );
1550        auto ctored_type = dynamic_cast<const ast::StructInstType *>( typeof_this );
1551        if ( ctored_type && ctored_type->base == thread_decl ) {
1552                thread_ctor_seen = true;
1553        }
1554
1555        // Modify this declaration, the extra checks to see if we will are first.
1556        const ast::ptr<ast::DeclWithType> & param = decl->params.front();
1557        auto type = dynamic_cast<const ast::StructInstType *>(
1558                InitTweak::getPointerBase( param->get_type() ) );
1559        if ( nullptr == type ) return decl;
1560        if ( !type->base->is_thread() ) return decl;
1561        if ( !thread_decl || !thread_ctor_seen ) {
1562                SemanticError( type->base->location, "thread keyword requires threads to be in scope, add #include <thread.hfa>" );
1563        }
1564        const ast::CompoundStmt * stmt = decl->stmts;
1565        if ( nullptr == stmt ) return decl;
1566
1567        // Now do the actual modification:
1568        ast::CompoundStmt * mutStmt = ast::mutate( stmt );
1569        const CodeLocation & location = mutStmt->location;
1570        mutStmt->push_back(
1571                new ast::ExprStmt(
1572                        location,
1573                        new ast::UntypedExpr(
1574                                location,
1575                                new ast::NameExpr( location, "__thrd_start" ),
1576                                {
1577                                        new ast::VariableExpr( location, param ),
1578                                        new ast::NameExpr( location, "main" ),
1579                                }
1580                        )
1581                )
1582        );
1583
1584        return ast::mutate_field( decl, &ast::FunctionDecl::stmts, mutStmt );
1585}
1586
1587} // namespace
1588
1589// --------------------------------------------------------------------------
1590// Interface Functions:
1591
1592void implementKeywords( ast::TranslationUnit & translationUnit ) {
1593        ast::Pass<ThreadKeyword>::run( translationUnit );
1594        ast::Pass<CoroutineKeyword>::run( translationUnit );
1595        ast::Pass<MonitorKeyword>::run( translationUnit );
1596        ast::Pass<GeneratorKeyword>::run( translationUnit );
1597        ast::Pass<SuspendKeyword>::run( translationUnit );
1598}
1599
1600void implementMutex( ast::TranslationUnit & translationUnit ) {
1601        ast::Pass<MutexKeyword>::run( translationUnit );
1602}
1603
1604void implementThreadStarter( ast::TranslationUnit & translationUnit ) {
1605        ast::Pass<ThreadStarter>::run( translationUnit );
1606}
1607
1608}
1609
1610// Local Variables: //
1611// tab-width: 4 //
1612// mode: c++ //
1613// compile-command: "make install" //
1614// End: //
Note: See TracBrowser for help on using the repository browser.