source: src/Concurrency/Keywords.cc @ 9a05b81

ADTarm-ehast-experimentalenumforall-pointer-decayjacob/cs343-translationnew-ast-unique-exprpthread-emulationqualifiedEnum
Last change on this file since 9a05b81 was 1c01c58, checked in by Andrew Beach <ajbeach@…>, 4 years ago

Rather large commit to get coroutine cancellation working.

This includes what you would expect, like new code in exceptions and a new
test, but it also includes a bunch of other things.

New coroutine state, currently just marks that the stack was cancelled. New
helpers for checking code structure and generating vtables. Changes to the
coroutine interface so resume may throw exceptions on cancellation, plus the
exception type that is thrown. Changes to the coroutine keyword generation to
generate exception code for each type of coroutine.

  • Property mode set to 100644
File size: 32.4 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// Keywords.cc --
8//
9// Author           : Thierry Delisle
10// Created On       : Mon Mar 13 12:41:22 2017
11// Last Modified By :
12// Last Modified On :
13// Update Count     : 10
14//
15
16#include "Concurrency/Keywords.h"
17
18#include <cassert>                        // for assert
19#include <string>                         // for string, operator==
20
21#include <iostream>
22
23#include "Common/Examine.h"               // for isMainFor
24#include "Common/PassVisitor.h"           // for PassVisitor
25#include "Common/SemanticError.h"         // for SemanticError
26#include "Common/utility.h"               // for deleteAll, map_range
27#include "CodeGen/OperatorTable.h"        // for isConstructor
28#include "ControlStruct/LabelGenerator.h" // for LebelGenerator
29#include "InitTweak/InitTweak.h"          // for getPointerBase
30#include "SynTree/LinkageSpec.h"          // for Cforall
31#include "SynTree/Constant.h"             // for Constant
32#include "SynTree/Declaration.h"          // for StructDecl, FunctionDecl, ObjectDecl
33#include "SynTree/Expression.h"           // for VariableExpr, ConstantExpr, Untype...
34#include "SynTree/Initializer.h"          // for SingleInit, ListInit, Initializer ...
35#include "SynTree/Label.h"                // for Label
36#include "SynTree/Statement.h"            // for CompoundStmt, DeclStmt, ExprStmt
37#include "SynTree/Type.h"                 // for StructInstType, Type, PointerType
38#include "SynTree/Visitor.h"              // for Visitor, acceptAll
39#include "Virtual/Tables.h"
40
41class Attribute;
42
43namespace Concurrency {
44        inline static std::string getVTableName( std::string const & exception_name ) {
45                return exception_name.empty() ? std::string() : Virtual::vtableTypeName(exception_name);
46        }
47
48        //=============================================================================================
49        // Pass declarations
50        //=============================================================================================
51
52        //-----------------------------------------------------------------------------
53        //Handles sue type declarations :
54        // sue MyType {                             struct MyType {
55        //      int data;                                  int data;
56        //      a_struct_t more_data;                      a_struct_t more_data;
57        //                                =>             NewField_t newField;
58        // };                                        };
59        //                                           static inline NewField_t * getter_name( MyType * this ) { return &this->newField; }
60        //
61        class ConcurrentSueKeyword : public WithDeclsToAdd {
62          public:
63
64                ConcurrentSueKeyword( std::string&& type_name, std::string&& field_name,
65                        std::string&& getter_name, std::string&& context_error, std::string&& exception_name,
66                        bool needs_main, AggregateDecl::Aggregate cast_target ) :
67                  type_name( type_name ), field_name( field_name ), getter_name( getter_name ),
68                  context_error( context_error ), vtable_name( getVTableName( exception_name ) ),
69                  needs_main( needs_main ), cast_target( cast_target ) {}
70
71                virtual ~ConcurrentSueKeyword() {}
72
73                Declaration * postmutate( StructDecl * decl );
74                DeclarationWithType * postmutate( FunctionDecl * decl );
75
76                void handle( StructDecl * );
77                void addVtableForward( StructDecl * );
78                FunctionDecl * forwardDeclare( StructDecl * );
79                ObjectDecl * addField( StructDecl * );
80                void addRoutines( ObjectDecl *, FunctionDecl * );
81
82                virtual bool is_target( StructDecl * decl ) = 0;
83
84                Expression * postmutate( KeywordCastExpr * cast );
85
86          private:
87                const std::string type_name;
88                const std::string field_name;
89                const std::string getter_name;
90                const std::string context_error;
91                const std::string vtable_name;
92                bool needs_main;
93                AggregateDecl::Aggregate cast_target;
94
95                StructDecl   * type_decl = nullptr;
96                FunctionDecl * dtor_decl = nullptr;
97                StructDecl * vtable_decl = nullptr;
98        };
99
100
101        //-----------------------------------------------------------------------------
102        //Handles thread type declarations :
103        // thread Mythread {                         struct MyThread {
104        //      int data;                                  int data;
105        //      a_struct_t more_data;                      a_struct_t more_data;
106        //                                =>             $thread __thrd_d;
107        // };                                        };
108        //                                           static inline $thread * get_thread( MyThread * this ) { return &this->__thrd_d; }
109        //
110        class ThreadKeyword final : public ConcurrentSueKeyword {
111          public:
112
113                ThreadKeyword() : ConcurrentSueKeyword(
114                        "$thread",
115                        "__thrd",
116                        "get_thread",
117                        "thread keyword requires threads to be in scope, add #include <thread.hfa>\n",
118                        "",
119                        true,
120                        AggregateDecl::Thread
121                )
122                {}
123
124                virtual ~ThreadKeyword() {}
125
126                virtual bool is_target( StructDecl * decl ) override final { return decl->is_thread(); }
127
128                static void implement( std::list< Declaration * > & translationUnit ) {
129                        PassVisitor< ThreadKeyword > impl;
130                        mutateAll( translationUnit, impl );
131                }
132        };
133
134        //-----------------------------------------------------------------------------
135        //Handles coroutine type declarations :
136        // coroutine MyCoroutine {                   struct MyCoroutine {
137        //      int data;                                  int data;
138        //      a_struct_t more_data;                      a_struct_t more_data;
139        //                                =>             $coroutine __cor_d;
140        // };                                        };
141        //                                           static inline $coroutine * get_coroutine( MyCoroutine * this ) { return &this->__cor_d; }
142        //
143        class CoroutineKeyword final : public ConcurrentSueKeyword {
144          public:
145
146                CoroutineKeyword() : ConcurrentSueKeyword(
147                        "$coroutine",
148                        "__cor",
149                        "get_coroutine",
150                        "coroutine keyword requires coroutines to be in scope, add #include <coroutine.hfa>\n",
151                        "CoroutineCancelled",
152                        true,
153                        AggregateDecl::Coroutine
154                )
155                {}
156
157                virtual ~CoroutineKeyword() {}
158
159                virtual bool is_target( StructDecl * decl ) override final { return decl->is_coroutine(); }
160
161                static void implement( std::list< Declaration * > & translationUnit ) {
162                        PassVisitor< CoroutineKeyword > impl;
163                        mutateAll( translationUnit, impl );
164                }
165        };
166
167
168
169        //-----------------------------------------------------------------------------
170        //Handles monitor type declarations :
171        // monitor MyMonitor {                       struct MyMonitor {
172        //      int data;                                  int data;
173        //      a_struct_t more_data;                      a_struct_t more_data;
174        //                                =>             $monitor __mon_d;
175        // };                                        };
176        //                                           static inline $monitor * get_coroutine( MyMonitor * this ) { return &this->__cor_d; }
177        //
178        class MonitorKeyword final : public ConcurrentSueKeyword {
179          public:
180
181                MonitorKeyword() : ConcurrentSueKeyword(
182                        "$monitor",
183                        "__mon",
184                        "get_monitor",
185                        "monitor keyword requires monitors to be in scope, add #include <monitor.hfa>\n",
186                        "",
187                        false,
188                        AggregateDecl::Monitor
189                )
190                {}
191
192                virtual ~MonitorKeyword() {}
193
194                virtual bool is_target( StructDecl * decl ) override final { return decl->is_monitor(); }
195
196                static void implement( std::list< Declaration * > & translationUnit ) {
197                        PassVisitor< MonitorKeyword > impl;
198                        mutateAll( translationUnit, impl );
199                }
200        };
201
202        //-----------------------------------------------------------------------------
203        //Handles generator type declarations :
204        // generator MyGenerator {                   struct MyGenerator {
205        //      int data;                                  int data;
206        //      a_struct_t more_data;                      a_struct_t more_data;
207        //                                =>             int __gen_next;
208        // };                                        };
209        //
210        class GeneratorKeyword final : public ConcurrentSueKeyword {
211          public:
212
213                GeneratorKeyword() : ConcurrentSueKeyword(
214                        "$generator",
215                        "__generator_state",
216                        "get_generator",
217                        "Unable to find builtin type $generator\n",
218                        "",
219                        true,
220                        AggregateDecl::Generator
221                )
222                {}
223
224                virtual ~GeneratorKeyword() {}
225
226                virtual bool is_target( StructDecl * decl ) override final { return decl->is_generator(); }
227
228                static void implement( std::list< Declaration * > & translationUnit ) {
229                        PassVisitor< GeneratorKeyword > impl;
230                        mutateAll( translationUnit, impl );
231                }
232        };
233
234
235        //-----------------------------------------------------------------------------
236        class SuspendKeyword final : public WithStmtsToAdd, public WithGuards {
237        public:
238                SuspendKeyword() = default;
239                virtual ~SuspendKeyword() = default;
240
241                void  premutate( FunctionDecl * );
242                DeclarationWithType * postmutate( FunctionDecl * );
243
244                Statement * postmutate( SuspendStmt * );
245
246                static void implement( std::list< Declaration * > & translationUnit ) {
247                        PassVisitor< SuspendKeyword > impl;
248                        mutateAll( translationUnit, impl );
249                }
250
251        private:
252                bool is_real_suspend( FunctionDecl * );
253
254                Statement * make_generator_suspend( SuspendStmt * );
255                Statement * make_coroutine_suspend( SuspendStmt * );
256
257                struct LabelPair {
258                        Label obj;
259                        int   idx;
260                };
261
262                LabelPair make_label() {
263                        labels.push_back( gen.newLabel("generator") );
264                        return { labels.back(), int(labels.size()) };
265                }
266
267                DeclarationWithType * in_generator = nullptr;
268                FunctionDecl * decl_suspend = nullptr;
269                std::vector<Label> labels;
270                ControlStruct::LabelGenerator & gen = *ControlStruct::LabelGenerator::getGenerator();
271        };
272
273        //-----------------------------------------------------------------------------
274        //Handles mutex routines definitions :
275        // void foo( A * mutex a, B * mutex b,  int i ) {                  void foo( A * a, B * b,  int i ) {
276        //                                                                       $monitor * __monitors[] = { get_monitor(a), get_monitor(b) };
277        //                                                                       monitor_guard_t __guard = { __monitors, 2 };
278        //    /*Some code*/                                       =>           /*Some code*/
279        // }                                                               }
280        //
281        class MutexKeyword final {
282          public:
283
284                void postvisit( FunctionDecl * decl );
285                void postvisit(   StructDecl * decl );
286
287                std::list<DeclarationWithType*> findMutexArgs( FunctionDecl*, bool & first );
288                void validate( DeclarationWithType * );
289                void addDtorStatments( FunctionDecl* func, CompoundStmt *, const std::list<DeclarationWithType * > &);
290                void addStatments( FunctionDecl* func, CompoundStmt *, const std::list<DeclarationWithType * > &);
291
292                static void implement( std::list< Declaration * > & translationUnit ) {
293                        PassVisitor< MutexKeyword > impl;
294                        acceptAll( translationUnit, impl );
295                }
296
297          private:
298                StructDecl* monitor_decl = nullptr;
299                StructDecl* guard_decl = nullptr;
300                StructDecl* dtor_guard_decl = nullptr;
301
302                static std::unique_ptr< Type > generic_func;
303        };
304
305        std::unique_ptr< Type > MutexKeyword::generic_func = std::unique_ptr< Type >(
306                new FunctionType(
307                        noQualifiers,
308                        true
309                )
310        );
311
312        //-----------------------------------------------------------------------------
313        //Handles mutex routines definitions :
314        // void foo( A * mutex a, B * mutex b,  int i ) {                  void foo( A * a, B * b,  int i ) {
315        //                                                                       $monitor * __monitors[] = { get_monitor(a), get_monitor(b) };
316        //                                                                       monitor_guard_t __guard = { __monitors, 2 };
317        //    /*Some code*/                                       =>           /*Some code*/
318        // }                                                               }
319        //
320        class ThreadStarter final {
321          public:
322
323                void postvisit( FunctionDecl * decl );
324                void previsit ( StructDecl   * decl );
325
326                void addStartStatement( FunctionDecl * decl, DeclarationWithType * param );
327
328                static void implement( std::list< Declaration * > & translationUnit ) {
329                        PassVisitor< ThreadStarter > impl;
330                        acceptAll( translationUnit, impl );
331                }
332
333          private :
334                bool thread_ctor_seen = false;
335                StructDecl * thread_decl = nullptr;
336        };
337
338        //=============================================================================================
339        // General entry routine
340        //=============================================================================================
341        void applyKeywords( std::list< Declaration * > & translationUnit ) {
342                ThreadKeyword   ::implement( translationUnit );
343                CoroutineKeyword        ::implement( translationUnit );
344                MonitorKeyword  ::implement( translationUnit );
345                GeneratorKeyword  ::implement( translationUnit );
346                SuspendKeyword    ::implement( translationUnit );
347        }
348
349        void implementMutexFuncs( std::list< Declaration * > & translationUnit ) {
350                MutexKeyword    ::implement( translationUnit );
351        }
352
353        void implementThreadStarter( std::list< Declaration * > & translationUnit ) {
354                ThreadStarter   ::implement( translationUnit );
355        }
356
357        //=============================================================================================
358        // Generic keyword implementation
359        //=============================================================================================
360        void fixupGenerics(FunctionType * func, StructDecl * decl) {
361                cloneAll(decl->parameters, func->forall);
362                for ( TypeDecl * td : func->forall ) {
363                        strict_dynamic_cast<StructInstType*>(
364                                func->parameters.front()->get_type()->stripReferences()
365                        )->parameters.push_back(
366                                new TypeExpr( new TypeInstType( noQualifiers, td->name, td ) )
367                        );
368                }
369        }
370
371        Declaration * ConcurrentSueKeyword::postmutate(StructDecl * decl) {
372                if( decl->name == type_name && decl->body ) {
373                        assert( !type_decl );
374                        type_decl = decl;
375                }
376                else if ( is_target(decl) ) {
377                        handle( decl );
378                }
379                else if ( !vtable_decl && vtable_name == decl->name && decl->body ) {
380                        vtable_decl = decl;
381                }
382                // Might be able to get ride of is target.
383                assert( is_target(decl) == (cast_target == decl->kind) );
384                return decl;
385        }
386
387        DeclarationWithType * ConcurrentSueKeyword::postmutate( FunctionDecl * decl ) {
388                if ( type_decl && isDestructorFor( decl, type_decl ) )
389                        dtor_decl = decl;
390                else if ( vtable_name.empty() )
391                        ;
392                else if ( auto param = isMainFor( decl, cast_target ) ) {
393                        // This should never trigger.
394                        assert( vtable_decl );
395                        // Should be safe because of isMainFor.
396                        StructInstType * struct_type = static_cast<StructInstType *>(
397                                static_cast<ReferenceType *>( param->get_type() )->base );
398                        assert( struct_type );
399
400                        declsToAddAfter.push_back( Virtual::makeVtableInstance( vtable_decl, {
401                                new TypeExpr( struct_type->clone() ),
402                        }, struct_type, nullptr ) );
403                }
404
405                return decl;
406        }
407
408        Expression * ConcurrentSueKeyword::postmutate( KeywordCastExpr * cast ) {
409                if ( cast_target == cast->target ) {
410                        // convert (thread &)t to ($thread &)*get_thread(t), etc.
411                        if( !type_decl ) SemanticError( cast, context_error );
412                        if( !dtor_decl ) SemanticError( cast, context_error );
413                        assert( cast->result == nullptr );
414                        cast->set_result( new ReferenceType( noQualifiers, new StructInstType( noQualifiers, type_decl ) ) );
415                        cast->concrete_target.field  = field_name;
416                        cast->concrete_target.getter = getter_name;
417                }
418                return cast;
419        }
420
421
422        void ConcurrentSueKeyword::handle( StructDecl * decl ) {
423                if( ! decl->body ) return;
424
425                if( !type_decl ) SemanticError( decl, context_error );
426                if( !dtor_decl ) SemanticError( decl, context_error );
427
428                addVtableForward( decl );
429                FunctionDecl * func = forwardDeclare( decl );
430                ObjectDecl * field = addField( decl );
431                addRoutines( field, func );
432        }
433
434        void ConcurrentSueKeyword::addVtableForward( StructDecl * decl ) {
435                if ( vtable_decl ) {
436                        declsToAddBefore.push_back( Virtual::makeVtableForward( vtable_decl, {
437                                new TypeExpr( new StructInstType( noQualifiers, decl ) ),
438                        } ) );
439                // Its only an error if we want a vtable and don't have one.
440                } else if ( ! vtable_name.empty() ) {
441                        SemanticError( decl, context_error );
442                }
443        }
444
445        FunctionDecl * ConcurrentSueKeyword::forwardDeclare( StructDecl * decl ) {
446
447                StructDecl * forward = decl->clone();
448                forward->set_body( false );
449                deleteAll( forward->get_members() );
450                forward->get_members().clear();
451
452                FunctionType * get_type = new FunctionType( noQualifiers, false );
453                ObjectDecl * this_decl = new ObjectDecl(
454                        "this",
455                        noStorageClasses,
456                        LinkageSpec::Cforall,
457                        nullptr,
458                        new ReferenceType(
459                                noQualifiers,
460                                new StructInstType(
461                                        noQualifiers,
462                                        decl
463                                )
464                        ),
465                        nullptr
466                );
467
468                get_type->get_parameters().push_back( this_decl->clone() );
469                get_type->get_returnVals().push_back(
470                        new ObjectDecl(
471                                "ret",
472                                noStorageClasses,
473                                LinkageSpec::Cforall,
474                                nullptr,
475                                new PointerType(
476                                        noQualifiers,
477                                        new StructInstType(
478                                                noQualifiers,
479                                                type_decl
480                                        )
481                                ),
482                                nullptr
483                        )
484                );
485                fixupGenerics(get_type, decl);
486
487                FunctionDecl * get_decl = new FunctionDecl(
488                        getter_name,
489                        Type::Static,
490                        LinkageSpec::Cforall,
491                        get_type,
492                        nullptr,
493                        { new Attribute("const") },
494                        Type::Inline
495                );
496
497                FunctionDecl * main_decl = nullptr;
498
499                if( needs_main ) {
500                        FunctionType * main_type = new FunctionType( noQualifiers, false );
501
502                        main_type->get_parameters().push_back( this_decl->clone() );
503
504                        main_decl = new FunctionDecl(
505                                "main",
506                                noStorageClasses,
507                                LinkageSpec::Cforall,
508                                main_type,
509                                nullptr
510                        );
511                        fixupGenerics(main_type, decl);
512                }
513
514                delete this_decl;
515
516                declsToAddBefore.push_back( forward );
517                if( needs_main ) declsToAddBefore.push_back( main_decl );
518                declsToAddBefore.push_back( get_decl );
519
520                return get_decl;
521        }
522
523        ObjectDecl * ConcurrentSueKeyword::addField( StructDecl * decl ) {
524                ObjectDecl * field = new ObjectDecl(
525                        field_name,
526                        noStorageClasses,
527                        LinkageSpec::Cforall,
528                        nullptr,
529                        new StructInstType(
530                                noQualifiers,
531                                type_decl
532                        ),
533                        nullptr
534                );
535
536                decl->get_members().push_back( field );
537
538                return field;
539        }
540
541        void ConcurrentSueKeyword::addRoutines( ObjectDecl * field, FunctionDecl * func ) {
542                CompoundStmt * statement = new CompoundStmt();
543                statement->push_back(
544                        new ReturnStmt(
545                                new AddressExpr(
546                                        new MemberExpr(
547                                                field,
548                                                new CastExpr(
549                                                        new VariableExpr( func->get_functionType()->get_parameters().front() ),
550                                                        func->get_functionType()->get_parameters().front()->get_type()->stripReferences()->clone(),
551                                                        false
552                                                )
553                                        )
554                                )
555                        )
556                );
557
558                FunctionDecl * get_decl = func->clone();
559
560                get_decl->set_statements( statement );
561
562                declsToAddAfter.push_back( get_decl );
563        }
564
565        //=============================================================================================
566        // Suspend keyword implementation
567        //=============================================================================================
568        bool SuspendKeyword::is_real_suspend( FunctionDecl * func ) {
569                if(isMangled(func->linkage)) return false; // the real suspend isn't mangled
570                if(func->name != "__cfactx_suspend") return false; // the real suspend has a specific name
571                if(func->type->parameters.size() != 0) return false; // Too many parameters
572                if(func->type->returnVals.size() != 0) return false; // Too many return values
573
574                return true;
575        }
576
577        void SuspendKeyword::premutate( FunctionDecl * func ) {
578                GuardValue(in_generator);
579                in_generator = nullptr;
580
581                // Is this the real suspend?
582                if(is_real_suspend(func)) {
583                        decl_suspend = decl_suspend ? decl_suspend : func;
584                        return;
585                }
586
587                // Is this the main of a generator?
588                auto param = isMainFor( func, AggregateDecl::Aggregate::Generator );
589                if(!param) return;
590
591                if(func->type->returnVals.size() != 0) SemanticError(func->location, "Generator main must return void");
592
593                in_generator = param;
594                GuardValue(labels);
595                labels.clear();
596        }
597
598        DeclarationWithType * SuspendKeyword::postmutate( FunctionDecl * func ) {
599                if( !func->statements ) return func; // Not the actual definition, don't do anything
600                if( !in_generator     ) return func; // Not in a generator, don't do anything
601                if( labels.empty()    ) return func; // Generator has no states, nothing to do, could throw a warning
602
603                // This is a generator main, we need to add the following code to the top
604                // static void * __generator_labels[] = {&&s0, &&s1, ...};
605                // goto * __generator_labels[gen.__generator_state];
606                const auto & loc = func->location;
607
608                const auto first_label = gen.newLabel("generator");
609
610                // for each label add to declaration
611                std::list<Initializer*> inits = { new SingleInit( new LabelAddressExpr( first_label ) ) };
612                for(const auto & label : labels) {
613                        inits.push_back(
614                                new SingleInit(
615                                        new LabelAddressExpr( label )
616                                )
617                        );
618                }
619                auto init = new ListInit(std::move(inits), noDesignators, true);
620                labels.clear();
621
622                // create decl
623                auto decl = new ObjectDecl(
624                        "__generator_labels",
625                        Type::StorageClasses( Type::Static ),
626                        LinkageSpec::AutoGen,
627                        nullptr,
628                        new ArrayType(
629                                Type::Qualifiers(),
630                                new PointerType(
631                                        Type::Qualifiers(),
632                                        new VoidType( Type::Qualifiers() )
633                                ),
634                                nullptr,
635                                false, false
636                        ),
637                        init
638                );
639
640                // create the goto
641                assert(in_generator);
642
643                auto go_decl = new ObjectDecl(
644                        "__generator_label",
645                        noStorageClasses,
646                        LinkageSpec::AutoGen,
647                        nullptr,
648                        new PointerType(
649                                Type::Qualifiers(),
650                                new VoidType( Type::Qualifiers() )
651                        ),
652                        new SingleInit(
653                                new UntypedExpr(
654                                        new NameExpr("?[?]"),
655                                        {
656                                                new NameExpr("__generator_labels"),
657                                                new UntypedMemberExpr(
658                                                        new NameExpr("__generator_state"),
659                                                        new VariableExpr( in_generator )
660                                                )
661                                        }
662                                )
663                        )
664                );
665                go_decl->location = loc;
666
667                auto go = new BranchStmt(
668                        new VariableExpr( go_decl ),
669                        BranchStmt::Goto
670                );
671                go->location = loc;
672                go->computedTarget->location = loc;
673
674                auto noop = new NullStmt({ first_label });
675                noop->location = loc;
676
677                // wrap everything in a nice compound
678                auto body = new CompoundStmt({
679                        new DeclStmt( decl ),
680                        new DeclStmt( go_decl ),
681                        go,
682                        noop,
683                        func->statements
684                });
685                body->location   = loc;
686                func->statements = body;
687
688                return func;
689        }
690
691        Statement * SuspendKeyword::postmutate( SuspendStmt * stmt ) {
692                SuspendStmt::Type type = stmt->type;
693                if(type == SuspendStmt::None) {
694                        // This suspend has a implicit target, find it
695                        type = in_generator ? SuspendStmt::Generator : SuspendStmt::Coroutine;
696                }
697
698                // Check that the target makes sense
699                if(!in_generator && type == SuspendStmt::Generator) SemanticError( stmt->location, "'suspend generator' must be used inside main of generator type.");
700
701                // Act appropriately
702                switch(type) {
703                        case SuspendStmt::Generator: return make_generator_suspend(stmt);
704                        case SuspendStmt::Coroutine: return make_coroutine_suspend(stmt);
705                        default: abort();
706                }
707        }
708
709        Statement * SuspendKeyword::make_generator_suspend( SuspendStmt * stmt ) {
710                assert(in_generator);
711                // Target code is :
712                //   gen.__generator_state = X;
713                //   { THEN }
714                //   return;
715                //   __gen_X:;
716
717                // Save the location and delete the old statement, we only need the location from this point on
718                auto loc = stmt->location;
719
720                // Build the label and get its index
721                auto label = make_label();
722
723                // Create the context saving statement
724                auto save = new ExprStmt( new UntypedExpr(
725                        new NameExpr( "?=?" ),
726                        {
727                                new UntypedMemberExpr(
728                                        new NameExpr("__generator_state"),
729                                        new VariableExpr( in_generator )
730                                ),
731                                new ConstantExpr(
732                                        Constant::from_int( label.idx )
733                                )
734                        }
735                ));
736                assert(save->expr);
737                save->location = loc;
738                stmtsToAddBefore.push_back( save );
739
740                // if we have a then add it here
741                auto then = stmt->then;
742                stmt->then = nullptr;
743                delete stmt;
744                if(then) stmtsToAddBefore.push_back( then );
745
746                // Create the return statement
747                auto ret = new ReturnStmt( nullptr );
748                ret->location = loc;
749                stmtsToAddBefore.push_back( ret );
750
751                // Create the null statement with the created label
752                auto noop = new NullStmt({ label.obj });
753                noop->location = loc;
754
755                // Return the null statement to take the place of the previous statement
756                return noop;
757        }
758
759        Statement * SuspendKeyword::make_coroutine_suspend( SuspendStmt * stmt ) {
760                if(stmt->then) SemanticError( stmt->location, "Compound statement following coroutines is not implemented.");
761
762                // Save the location and delete the old statement, we only need the location from this point on
763                auto loc = stmt->location;
764                delete stmt;
765
766                // Create the call expression
767                if(!decl_suspend) SemanticError( loc, "suspend keyword applied to coroutines requires coroutines to be in scope, add #include <coroutine.hfa>\n");
768                auto expr = new UntypedExpr( VariableExpr::functionPointer( decl_suspend ) );
769                expr->location = loc;
770
771                // Change this statement into a regular expr
772                assert(expr);
773                auto nstmt = new ExprStmt( expr );
774                nstmt->location = loc;
775                return nstmt;
776        }
777
778
779        //=============================================================================================
780        // Mutex keyword implementation
781        //=============================================================================================
782
783        void MutexKeyword::postvisit(FunctionDecl* decl) {
784
785                bool first = false;
786                std::list<DeclarationWithType*> mutexArgs = findMutexArgs( decl, first );
787                bool isDtor = CodeGen::isDestructor( decl->name );
788
789                // Is this function relevant to monitors
790                if( mutexArgs.empty() ) {
791                        // If this is the destructor for a monitor it must be mutex
792                        if(isDtor) {
793                                Type* ty = decl->get_functionType()->get_parameters().front()->get_type();
794
795                                // If it's a copy, it's not a mutex
796                                ReferenceType* rty = dynamic_cast< ReferenceType * >( ty );
797                                if( ! rty ) return;
798
799                                // If we are not pointing directly to a type, it's not a mutex
800                                Type* base = rty->get_base();
801                                if( dynamic_cast< ReferenceType * >( base ) ) return;
802                                if( dynamic_cast< PointerType * >( base ) ) return;
803
804                                // Check if its a struct
805                                StructInstType * baseStruct = dynamic_cast< StructInstType * >( base );
806                                if( !baseStruct ) return;
807
808                                // Check if its a monitor
809                                if(baseStruct->baseStruct->is_monitor() || baseStruct->baseStruct->is_thread())
810                                        SemanticError( decl, "destructors for structures declared as \"monitor\" must use mutex parameters\n" );
811                        }
812                        return;
813                }
814
815                // Monitors can't be constructed with mutual exclusion
816                if( CodeGen::isConstructor(decl->name) && !first ) SemanticError( decl, "constructors cannot have mutex parameters" );
817
818                // It makes no sense to have multiple mutex parameters for the destructor
819                if( isDtor && mutexArgs.size() != 1 ) SemanticError( decl, "destructors can only have 1 mutex argument" );
820
821                // Make sure all the mutex arguments are monitors
822                for(auto arg : mutexArgs) {
823                        validate( arg );
824                }
825
826                // Check if we need to instrument the body
827                CompoundStmt* body = decl->get_statements();
828                if( ! body ) return;
829
830                // Do we have the required headers
831                if( !monitor_decl || !guard_decl || !dtor_guard_decl )
832                        SemanticError( decl, "mutex keyword requires monitors to be in scope, add #include <monitor.hfa>\n" );
833
834                // Instrument the body
835                if( isDtor ) {
836                        addDtorStatments( decl, body, mutexArgs );
837                }
838                else {
839                        addStatments( decl, body, mutexArgs );
840                }
841        }
842
843        void MutexKeyword::postvisit(StructDecl* decl) {
844
845                if( decl->name == "$monitor" && decl->body ) {
846                        assert( !monitor_decl );
847                        monitor_decl = decl;
848                }
849                else if( decl->name == "monitor_guard_t" && decl->body ) {
850                        assert( !guard_decl );
851                        guard_decl = decl;
852                }
853                else if( decl->name == "monitor_dtor_guard_t" && decl->body ) {
854                        assert( !dtor_guard_decl );
855                        dtor_guard_decl = decl;
856                }
857        }
858
859        std::list<DeclarationWithType*> MutexKeyword::findMutexArgs( FunctionDecl* decl, bool & first ) {
860                std::list<DeclarationWithType*> mutexArgs;
861
862                bool once = true;
863                for( auto arg : decl->get_functionType()->get_parameters()) {
864                        //Find mutex arguments
865                        Type* ty = arg->get_type();
866                        if( ! ty->get_mutex() ) continue;
867
868                        if(once) {first = true;}
869                        once = false;
870
871                        //Append it to the list
872                        mutexArgs.push_back( arg );
873                }
874
875                return mutexArgs;
876        }
877
878        void MutexKeyword::validate( DeclarationWithType * arg ) {
879                Type* ty = arg->get_type();
880
881                //Makes sure it's not a copy
882                ReferenceType* rty = dynamic_cast< ReferenceType * >( ty );
883                if( ! rty ) SemanticError( arg, "Mutex argument must be of reference type " );
884
885                //Make sure the we are pointing directly to a type
886                Type* base = rty->get_base();
887                if( dynamic_cast< ReferenceType * >( base ) ) SemanticError( arg, "Mutex argument have exactly one level of indirection " );
888                if( dynamic_cast< PointerType * >( base ) ) SemanticError( arg, "Mutex argument have exactly one level of indirection " );
889
890                //Make sure that typed isn't mutex
891                if( base->get_mutex() ) SemanticError( arg, "mutex keyword may only appear once per argument " );
892        }
893
894        void MutexKeyword::addDtorStatments( FunctionDecl* func, CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
895                Type * arg_type = args.front()->get_type()->clone();
896                arg_type->set_mutex( false );
897
898                ObjectDecl * monitors = new ObjectDecl(
899                        "__monitor",
900                        noStorageClasses,
901                        LinkageSpec::Cforall,
902                        nullptr,
903                        new PointerType(
904                                noQualifiers,
905                                new StructInstType(
906                                        noQualifiers,
907                                        monitor_decl
908                                )
909                        ),
910                        new SingleInit( new UntypedExpr(
911                                new NameExpr( "get_monitor" ),
912                                {  new CastExpr( new VariableExpr( args.front() ), arg_type, false ) }
913                        ))
914                );
915
916                assert(generic_func);
917
918                //in reverse order :
919                // monitor_dtor_guard_t __guard = { __monitors, func };
920                body->push_front(
921                        new DeclStmt( new ObjectDecl(
922                                "__guard",
923                                noStorageClasses,
924                                LinkageSpec::Cforall,
925                                nullptr,
926                                new StructInstType(
927                                        noQualifiers,
928                                        dtor_guard_decl
929                                ),
930                                new ListInit(
931                                        {
932                                                new SingleInit( new AddressExpr( new VariableExpr( monitors ) ) ),
933                                                new SingleInit( new CastExpr( new VariableExpr( func ), generic_func->clone(), false ) )
934                                        },
935                                        noDesignators,
936                                        true
937                                )
938                        ))
939                );
940
941                //$monitor * __monitors[] = { get_monitor(a), get_monitor(b) };
942                body->push_front( new DeclStmt( monitors) );
943        }
944
945        void MutexKeyword::addStatments( FunctionDecl* func, CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
946                ObjectDecl * monitors = new ObjectDecl(
947                        "__monitors",
948                        noStorageClasses,
949                        LinkageSpec::Cforall,
950                        nullptr,
951                        new ArrayType(
952                                noQualifiers,
953                                new PointerType(
954                                        noQualifiers,
955                                        new StructInstType(
956                                                noQualifiers,
957                                                monitor_decl
958                                        )
959                                ),
960                                new ConstantExpr( Constant::from_ulong( args.size() ) ),
961                                false,
962                                false
963                        ),
964                        new ListInit(
965                                map_range < std::list<Initializer*> > ( args, [](DeclarationWithType * var ){
966                                        Type * type = var->get_type()->clone();
967                                        type->set_mutex( false );
968                                        return new SingleInit( new UntypedExpr(
969                                                new NameExpr( "get_monitor" ),
970                                                {  new CastExpr( new VariableExpr( var ), type, false ) }
971                                        ) );
972                                })
973                        )
974                );
975
976                assert(generic_func);
977
978                // in reverse order :
979                // monitor_guard_t __guard = { __monitors, #, func };
980                body->push_front(
981                        new DeclStmt( new ObjectDecl(
982                                "__guard",
983                                noStorageClasses,
984                                LinkageSpec::Cforall,
985                                nullptr,
986                                new StructInstType(
987                                        noQualifiers,
988                                        guard_decl
989                                ),
990                                new ListInit(
991                                        {
992                                                new SingleInit( new VariableExpr( monitors ) ),
993                                                new SingleInit( new ConstantExpr( Constant::from_ulong( args.size() ) ) ),
994                                                new SingleInit( new CastExpr( new VariableExpr( func ), generic_func->clone(), false ) )
995                                        },
996                                        noDesignators,
997                                        true
998                                )
999                        ))
1000                );
1001
1002                //$monitor * __monitors[] = { get_monitor(a), get_monitor(b) };
1003                body->push_front( new DeclStmt( monitors) );
1004        }
1005
1006        //=============================================================================================
1007        // General entry routine
1008        //=============================================================================================
1009        void ThreadStarter::previsit( StructDecl * decl ) {
1010                if( decl->name == "$thread" && decl->body ) {
1011                        assert( !thread_decl );
1012                        thread_decl = decl;
1013                }
1014        }
1015
1016        void ThreadStarter::postvisit(FunctionDecl * decl) {
1017                if( ! CodeGen::isConstructor(decl->name) ) return;
1018
1019                Type * typeof_this = InitTweak::getTypeofThis(decl->type);
1020                StructInstType * ctored_type = dynamic_cast< StructInstType * >( typeof_this );
1021                if( ctored_type && ctored_type->baseStruct == thread_decl ) {
1022                        thread_ctor_seen = true;
1023                }
1024
1025                DeclarationWithType * param = decl->get_functionType()->get_parameters().front();
1026                auto type  = dynamic_cast< StructInstType * >( InitTweak::getPointerBase( param->get_type() ) );
1027                if( type && type->get_baseStruct()->is_thread() ) {
1028                        if( !thread_decl || !thread_ctor_seen ) {
1029                                SemanticError( type->get_baseStruct()->location, "thread keyword requires threads to be in scope, add #include <thread.hfa>");
1030                        }
1031
1032                        addStartStatement( decl, param );
1033                }
1034        }
1035
1036        void ThreadStarter::addStartStatement( FunctionDecl * decl, DeclarationWithType * param ) {
1037                CompoundStmt * stmt = decl->get_statements();
1038
1039                if( ! stmt ) return;
1040
1041                stmt->push_back(
1042                        new ExprStmt(
1043                                new UntypedExpr(
1044                                        new NameExpr( "__thrd_start" ),
1045                                        { new VariableExpr( param ), new NameExpr("main") }
1046                                )
1047                        )
1048                );
1049        }
1050};
1051
1052// Local Variables: //
1053// mode: c //
1054// tab-width: 4 //
1055// End: //
1056
Note: See TracBrowser for help on using the repository browser.