source: src/Virtual/ExpandCasts.cc @ 251ce80

ast-experimental
Last change on this file since 251ce80 was bccd70a, checked in by Andrew Beach <ajbeach@…>, 13 months ago

Removed internal code from TypeSubstitution? header. It caused a chain of include problems, which have been corrected.

  • Property mode set to 100644
File size: 15.8 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// ExpandCasts.cc --
8//
9// Author           : Andrew Beach
10// Created On       : Mon Jul 24 13:59:00 2017
11// Last Modified By : Andrew Beach
12// Last Modified On : Thu Aug 11 12:06:00 2022
13// Update Count     : 5
14//
15
16#include "ExpandCasts.h"
17
18#include <cassert>                 // for assert, assertf
19#include <iterator>                // for back_inserter, inserter
20#include <string>                  // for string, allocator, operator==, ope...
21
22#include "AST/Copy.hpp"
23#include "AST/Decl.hpp"
24#include "AST/Expr.hpp"
25#include "AST/Pass.hpp"
26#include "Common/PassVisitor.h"    // for PassVisitor
27#include "Common/ScopedMap.h"      // for ScopedMap
28#include "Common/SemanticError.h"  // for SemanticError
29#include "SymTab/Mangler.h"        // for mangleType
30#include "SynTree/Declaration.h"   // for ObjectDecl, StructDecl, FunctionDecl
31#include "SynTree/Expression.h"    // for VirtualCastExpr, CastExpr, Address...
32#include "SynTree/Mutator.h"       // for mutateAll
33#include "SynTree/Type.h"          // for Type, PointerType, StructInstType
34#include "SynTree/Visitor.h"       // for acceptAll
35
36namespace Virtual {
37
38namespace {
39
40bool is_prefix( const std::string & prefix, const std::string& entire ) {
41        size_t const p_size = prefix.size();
42        return (p_size < entire.size() && prefix == entire.substr(0, p_size));
43}
44
45bool is_type_id_object( const ObjectDecl * objectDecl ) {
46        const std::string & objectName = objectDecl->name;
47        return is_prefix( "__cfatid_", objectName );
48}
49
50bool is_type_id_object( const ast::ObjectDecl * decl ) {
51        return is_prefix( "__cfatid_", decl->name );
52}
53
54        // Indented until the new ast code gets added.
55
56        /// Maps virtual table types the instance for that type.
57        class VirtualTableMap final {
58                ScopedMap<std::string, ObjectDecl *> vtable_instances;
59        public:
60                void enterScope() {
61                        vtable_instances.beginScope();
62                }
63                void leaveScope() {
64                        vtable_instances.endScope();
65                }
66
67                ObjectDecl * insert( ObjectDecl * vtableDecl ) {
68                        std::string const & mangledName = SymTab::Mangler::mangleType( vtableDecl->type );
69                        ObjectDecl *& value = vtable_instances[ mangledName ];
70                        if ( value ) {
71                                if ( vtableDecl->storageClasses.is_extern ) {
72                                        return nullptr;
73                                } else if ( ! value->storageClasses.is_extern ) {
74                                        return value;
75                                }
76                        }
77                        value = vtableDecl;
78                        return nullptr;
79                }
80
81                ObjectDecl * lookup( const Type * vtableType ) {
82                        std::string const & mangledName = SymTab::Mangler::mangleType( vtableType );
83                        const auto it = vtable_instances.find( mangledName );
84                        return ( vtable_instances.end() == it ) ? nullptr : it->second;
85                }
86        };
87
88        class VirtualCastCore {
89                CastExpr * cast_to_type_id( Expression * expr, int level_of_indirection ) {
90                        Type * type = new StructInstType(
91                                Type::Qualifiers( Type::Const ), pvt_decl );
92                        for (int i = 0 ; i < level_of_indirection ; ++i) {
93                                type = new PointerType( noQualifiers, type );
94                        }
95                        return new CastExpr( expr, type );
96                }
97
98        public:
99                VirtualCastCore() :
100                        indexer(), vcast_decl( nullptr ), pvt_decl( nullptr )
101                {}
102
103                void premutate( FunctionDecl * functionDecl );
104                void premutate( StructDecl * structDecl );
105                void premutate( ObjectDecl * objectDecl );
106
107                Expression * postmutate( VirtualCastExpr * castExpr );
108
109                VirtualTableMap indexer;
110        private:
111                FunctionDecl *vcast_decl;
112                StructDecl *pvt_decl;
113        };
114
115        void VirtualCastCore::premutate( FunctionDecl * functionDecl ) {
116                if ( (! vcast_decl) &&
117                     functionDecl->get_name() == "__cfavir_virtual_cast" ) {
118                        vcast_decl = functionDecl;
119                }
120        }
121
122        void VirtualCastCore::premutate( StructDecl * structDecl ) {
123                if ( pvt_decl || ! structDecl->has_body() ) {
124                        return;
125                } else if ( structDecl->get_name() == "__cfavir_type_info" ) {
126                        pvt_decl = structDecl;
127                }
128        }
129
130        void VirtualCastCore::premutate( ObjectDecl * objectDecl ) {
131                if ( is_type_id_object( objectDecl ) ) {
132                        // Multiple definitions should be fine because of linkonce.
133                        indexer.insert( objectDecl );
134                }
135        }
136
137        /// Better error locations for generated casts.
138        CodeLocation castLocation( const VirtualCastExpr * castExpr ) {
139                if ( castExpr->location.isSet() ) {
140                        return castExpr->location;
141                } else if ( castExpr->arg->location.isSet() ) {
142                        return castExpr->arg->location;
143                } else if ( castExpr->result->location.isSet() ) {
144                        return castExpr->result->location;
145                } else {
146                        return CodeLocation();
147                }
148        }
149
150        [[noreturn]] void castError( const VirtualCastExpr * castExpr, std::string const & message ) {
151                SemanticError( castLocation( castExpr ), message );
152        }
153
154        /// Get the base type from a pointer or reference.
155        const Type * getBaseType( const Type * type ) {
156                if ( auto target = dynamic_cast<const PointerType *>( type ) ) {
157                        return target->base;
158                } else if ( auto target = dynamic_cast<const ReferenceType *>( type ) ) {
159                        return target->base;
160                } else {
161                        return nullptr;
162                }
163        }
164
165        /* Attempt to follow the "head" field of the structure to get the...
166         * Returns nullptr on error, otherwise owner must free returned node.
167         */
168        StructInstType * followHeadPointerType(
169                        const StructInstType * oldType,
170                        const std::string& fieldName,
171                        const CodeLocation& errorLocation ) {
172
173                // First section of the function is all about trying to fill this variable in.
174                StructInstType * newType = nullptr;
175                {
176                        const StructDecl * oldDecl = oldType->baseStruct;
177                        assert( oldDecl );
178
179                        // Helper function for throwing semantic errors.
180                        auto throwError = [&fieldName, &errorLocation, &oldDecl](const std::string& message) {
181                                const std::string& context = "While following head pointer of " +
182                                        oldDecl->name + " named '" + fieldName + "': ";
183                                SemanticError( errorLocation, context + message );
184                        };
185
186                        if ( oldDecl->members.empty() ) {
187                                throwError( "Type has no fields." );
188                        }
189                        const Declaration * memberDecl = oldDecl->members.front();
190                        assert( memberDecl );
191                        const ObjectDecl * fieldDecl = dynamic_cast<const ObjectDecl *>( memberDecl );
192                        assert( fieldDecl );
193                        if ( fieldName != fieldDecl->name ) {
194                                throwError( "Head field did not have expected name." );
195                        }
196
197                        const Type * fieldType = fieldDecl->type;
198                        if ( nullptr == fieldType ) {
199                                throwError( "Could not get head field." );
200                        }
201                        const PointerType * ptrType = dynamic_cast<const PointerType *>( fieldType );
202                        if ( nullptr == ptrType ) {
203                                throwError( "First field is not a pointer type." );
204                        }
205                        assert( ptrType->base );
206                        newType = dynamic_cast<StructInstType *>( ptrType->base );
207                        if ( nullptr == newType ) {
208                                throwError( "First field does not point to a structure type." );
209                        }
210                }
211
212                // Now we can look into copying it.
213                newType = newType->clone();
214                if ( ! oldType->parameters.empty() ) {
215                        deleteAll( newType->parameters );
216                        newType->parameters.clear();
217                        cloneAll( oldType->parameters, newType->parameters );
218                }
219                return newType;
220        }
221
222        /// Get the type-id type from a virtual type.
223        StructInstType * getTypeIdType( const Type * type, const CodeLocation& errorLocation ) {
224                const StructInstType * typeInst = dynamic_cast<const StructInstType *>( type );
225                if ( nullptr == typeInst ) {
226                        return nullptr;
227                }
228                StructInstType * tableInst =
229                        followHeadPointerType( typeInst, "virtual_table", errorLocation );
230                if ( nullptr == tableInst ) {
231                        return nullptr;
232                }
233                StructInstType * typeIdInst =
234                        followHeadPointerType( tableInst, "__cfavir_typeid", errorLocation );
235                delete tableInst;
236                return typeIdInst;
237        }
238
239        Expression * VirtualCastCore::postmutate( VirtualCastExpr * castExpr ) {
240                assertf( castExpr->result, "Virtual Cast target not found before expansion." );
241
242                assert( vcast_decl );
243                assert( pvt_decl );
244
245                const Type * base_type = getBaseType( castExpr->result );
246                if ( nullptr == base_type ) {
247                        castError( castExpr, "Virtual cast target must be a pointer or reference type." );
248                }
249                const Type * type_id_type = getTypeIdType( base_type, castLocation( castExpr ) );
250                if ( nullptr == type_id_type ) {
251                        castError( castExpr, "Ill formed virtual cast target type." );
252                }
253                ObjectDecl * type_id = indexer.lookup( type_id_type );
254                delete type_id_type;
255                if ( nullptr == type_id ) {
256                        castError( castExpr, "Virtual cast does not target a virtual type." );
257                }
258
259                Expression * result = new CastExpr(
260                        new ApplicationExpr( VariableExpr::functionPointer( vcast_decl ), {
261                                cast_to_type_id( new AddressExpr( new VariableExpr( type_id ) ), 1 ),
262                                cast_to_type_id( castExpr->get_arg(), 2 ),
263                        } ),
264                        castExpr->get_result()->clone()
265                );
266
267                castExpr->set_arg( nullptr );
268                castExpr->set_result( nullptr );
269                delete castExpr;
270                return result;
271        }
272
273/// Better error locations for generated casts.
274// TODO: Does the improved distribution of code locations make this unneeded?
275CodeLocation castLocation( const ast::VirtualCastExpr * castExpr ) {
276        if ( castExpr->location.isSet() ) {
277                return castExpr->location;
278        } else if ( castExpr->arg->location.isSet() ) {
279                return castExpr->arg->location;
280        } else {
281                return CodeLocation();
282        }
283}
284
285[[noreturn]] void castError( ast::VirtualCastExpr const * castExpr, std::string const & message ) {
286        SemanticError( castLocation( castExpr ), message );
287}
288
289class TypeIdTable final {
290        ScopedMap<std::string, ast::ObjectDecl const *> instances;
291public:
292        void enterScope() { instances.beginScope(); }
293        void leaveScope() { instances.endScope(); }
294
295        // Attempt to insert an instance into the map. If there is a conflict,
296        // returns the previous declaration for error messages.
297        ast::ObjectDecl const * insert( ast::ObjectDecl const * typeIdDecl ) {
298                std::string mangledName = Mangle::mangleType( typeIdDecl->type );
299                ast::ObjectDecl const *& value = instances[ mangledName ];
300                if ( value ) {
301                        if ( typeIdDecl->storage.is_extern ) {
302                                return nullptr;
303                        } else if ( !value->storage.is_extern ) {
304                                return value;
305                        }
306                }
307                value = typeIdDecl;
308                return nullptr;
309        }
310
311        ast::ObjectDecl const * lookup( ast::Type const * typeIdType ) {
312                std::string mangledName = Mangle::mangleType( typeIdType );
313                auto const it = instances.find( mangledName );
314                return ( instances.end() == it ) ? nullptr : it->second;
315        }
316};
317
318struct ExpandCastsCore final {
319        void previsit( ast::FunctionDecl const * decl );
320        void previsit( ast::StructDecl const * decl );
321        void previsit( ast::ObjectDecl const * decl );
322        ast::Expr const * postvisit( ast::VirtualCastExpr const * expr );
323
324        ast::CastExpr const * cast_to_type_id(
325                ast::Expr const * expr, unsigned int level_of_indirection );
326
327        ast::FunctionDecl const * vcast_decl = nullptr;
328        ast::StructDecl const * info_decl = nullptr;
329
330        TypeIdTable symtab;
331};
332
333void ExpandCastsCore::previsit( ast::FunctionDecl const * decl ) {
334        if ( !vcast_decl && "__cfavir_virtual_cast" == decl->name ) {
335                vcast_decl = decl;
336        }
337}
338
339void ExpandCastsCore::previsit( ast::StructDecl const * decl ) {
340        if ( !info_decl && decl->body && "__cfavir_type_info" == decl->name ) {
341                info_decl = decl;
342        }
343}
344
345void ExpandCastsCore::previsit( ast::ObjectDecl const * decl ) {
346        if ( is_type_id_object( decl ) ) {
347                // Multiple definitions should be fine because of linkonce.
348                symtab.insert( decl );
349        }
350}
351
352/// Get the base type from a pointer or reference.
353ast::Type const * getBaseType( ast::ptr<ast::Type> const & type ) {
354        if ( auto target = type.as<ast::PointerType>() ) {
355                return target->base.get();
356        } else if ( auto target = type.as<ast::ReferenceType>() ) {
357                return target->base.get();
358        } else {
359                return nullptr;
360        }
361}
362
363/// Copy newType, but give the copy the params of the oldType.
364ast::StructInstType * polyCopy(
365                ast::StructInstType const * oldType,
366                ast::StructInstType const * newType ) {
367        assert( oldType->params.size() == newType->params.size() );
368        ast::StructInstType * retType = ast::deepCopy( newType );
369        if ( ! oldType->params.empty() ) {
370                retType->params.clear();
371                for ( auto oldParams : oldType->params ) {
372                        retType->params.push_back( ast::deepCopy( oldParams ) );
373                }
374        }
375        return retType;
376}
377
378/// Follow the "head" field of the structure to get the type that is pointed
379/// to by that field.
380ast::StructInstType const * followHeadPointerType(
381                CodeLocation const & errorLocation,
382                ast::StructInstType const * oldType,
383                std::string const & fieldName ) {
384        ast::StructDecl const * oldDecl = oldType->base;
385        assert( oldDecl );
386
387        // Helper function for throwing semantic errors.
388        auto throwError = [&fieldName, &errorLocation, &oldDecl](
389                        std::string const & message ) {
390                std::string const & context = "While following head pointer of " +
391                        oldDecl->name + " named '" + fieldName + "': ";
392                SemanticError( errorLocation, context + message );
393        };
394
395        if ( oldDecl->members.empty() ) {
396                throwError( "Type has no fields." );
397        }
398        ast::ptr<ast::Decl> const & memberDecl = oldDecl->members.front();
399        assert( memberDecl );
400        ast::ObjectDecl const * fieldDecl = memberDecl.as<ast::ObjectDecl>();
401        assert( fieldDecl );
402        if ( fieldName != fieldDecl->name ) {
403                throwError( "Head field did not have expected name." );
404        }
405
406        ast::ptr<ast::Type> const & fieldType = fieldDecl->type;
407        if ( nullptr == fieldType ) {
408                throwError( "Could not get head field." );
409        }
410        auto ptrType = fieldType.as<ast::PointerType>();
411        if ( nullptr == ptrType ) {
412                throwError( "First field is not a pointer type." );
413        }
414        assert( ptrType->base );
415        auto newType = ptrType->base.as<ast::StructInstType>();
416        if ( nullptr == newType ) {
417                throwError( "First field does not point to a structure type." );
418        }
419
420        return polyCopy( oldType, newType );
421}
422
423/// Get the type-id type from a virtual type.
424ast::StructInstType const * getTypeIdType(
425                CodeLocation const & errorLocation,
426                ast::Type const * type ) {
427        auto typeInst = dynamic_cast<ast::StructInstType const *>( type );
428        if ( nullptr == typeInst ) {
429                return nullptr;
430        }
431        ast::ptr<ast::StructInstType> tableInst =
432                followHeadPointerType( errorLocation, typeInst, "virtual_table" );
433        if ( nullptr == tableInst ) {
434                return nullptr;
435        }
436        ast::StructInstType const * typeIdInst =
437                followHeadPointerType( errorLocation, tableInst, "__cfavir_typeid" );
438        return typeIdInst;
439}
440
441ast::Expr const * ExpandCastsCore::postvisit(
442                ast::VirtualCastExpr const * expr ) {
443        assertf( expr->result, "Virtual cast target not found before expansion." );
444
445        assert( vcast_decl );
446        assert( info_decl );
447
448        ast::Type const * base_type = getBaseType( expr->result );
449        if ( nullptr == base_type ) {
450                castError( expr, "Virtual cast target must be a pointer or reference type." );
451        }
452        ast::StructInstType const * type_id_type =
453                        getTypeIdType( castLocation( expr ), base_type );
454        if ( nullptr == type_id_type ) {
455                castError( expr, "Ill formed virtual cast target type." );
456        }
457        ast::ObjectDecl const * type_id = symtab.lookup( type_id_type );
458        if ( nullptr == type_id ) {
459                // I'm trying to give a different error for polymorpic types as
460                // different things can go wrong there.
461                if ( type_id_type->params.empty() ) {
462                        castError( expr, "Virtual cast does not target a virtual type." );
463                } else {
464                        castError( expr, "Virtual cast does not target a type with a "
465                                "type id (possible missing virtual table)." );
466                }
467        }
468
469        return new ast::CastExpr( expr->location,
470                new ast::ApplicationExpr( expr->location,
471                        ast::VariableExpr::functionPointer( expr->location, vcast_decl ),
472                        {
473                                cast_to_type_id(
474                                        new ast::AddressExpr( expr->location,
475                                                new ast::VariableExpr( expr->location, type_id ) ),
476                                        1 ),
477                                cast_to_type_id( expr->arg, 2 ),
478                        }
479                ),
480                ast::deepCopy( expr->result )
481        );
482}
483
484ast::CastExpr const * ExpandCastsCore::cast_to_type_id(
485                ast::Expr const * expr, unsigned int level_of_indirection ) {
486        assert( info_decl );
487        ast::Type * type = new ast::StructInstType( info_decl, ast::CV::Const );
488        for ( unsigned int i = 0 ; i < level_of_indirection ; ++i ) {
489                type = new ast::PointerType( type );
490        }
491        return new ast::CastExpr( expr->location, expr, type );
492}
493
494} // namespace
495
496void expandCasts( std::list< Declaration * > & translationUnit ) {
497        PassVisitor<VirtualCastCore> translator;
498        mutateAll( translationUnit, translator );
499}
500
501void expandCasts( ast::TranslationUnit & translationUnit ) {
502        ast::Pass<ExpandCastsCore>::run( translationUnit );
503}
504
505} // namespace Virtual
Note: See TracBrowser for help on using the repository browser.