source: src/Virtual/ExpandCasts.cc @ 9317419

ADTast-experimental
Last change on this file since 9317419 was 0026d67, checked in by Andrew Beach <ajbeach@…>, 2 years ago

Replaced Mangle::typeMode() with Mangle::mangleType(...), as it is how typeMode() was always used and it is shorter. Various other clean-up in the Mangler files.

  • Property mode set to 100644
File size: 15.8 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// ExpandCasts.cc --
8//
9// Author           : Andrew Beach
10// Created On       : Mon Jul 24 13:59:00 2017
11// Last Modified By : Andrew Beach
12// Last Modified On : Thu Aug 11 12:06:00 2022
13// Update Count     : 5
14//
15
16#include "ExpandCasts.h"
17
18#include <cassert>                 // for assert, assertf
19#include <iterator>                // for back_inserter, inserter
20#include <string>                  // for string, allocator, operator==, ope...
21
22#include "AST/Decl.hpp"
23#include "AST/Expr.hpp"
24#include "AST/Pass.hpp"
25#include "Common/PassVisitor.h"    // for PassVisitor
26#include "Common/ScopedMap.h"      // for ScopedMap
27#include "Common/SemanticError.h"  // for SemanticError
28#include "SymTab/Mangler.h"        // for mangleType
29#include "SynTree/Declaration.h"   // for ObjectDecl, StructDecl, FunctionDecl
30#include "SynTree/Expression.h"    // for VirtualCastExpr, CastExpr, Address...
31#include "SynTree/Mutator.h"       // for mutateAll
32#include "SynTree/Type.h"          // for Type, PointerType, StructInstType
33#include "SynTree/Visitor.h"       // for acceptAll
34
35namespace Virtual {
36
37namespace {
38
39bool is_prefix( const std::string & prefix, const std::string& entire ) {
40        size_t const p_size = prefix.size();
41        return (p_size < entire.size() && prefix == entire.substr(0, p_size));
42}
43
44bool is_type_id_object( const ObjectDecl * objectDecl ) {
45        const std::string & objectName = objectDecl->name;
46        return is_prefix( "__cfatid_", objectName );
47}
48
49bool is_type_id_object( const ast::ObjectDecl * decl ) {
50        return is_prefix( "__cfatid_", decl->name );
51}
52
53        // Indented until the new ast code gets added.
54
55        /// Maps virtual table types the instance for that type.
56        class VirtualTableMap final {
57                ScopedMap<std::string, ObjectDecl *> vtable_instances;
58        public:
59                void enterScope() {
60                        vtable_instances.beginScope();
61                }
62                void leaveScope() {
63                        vtable_instances.endScope();
64                }
65
66                ObjectDecl * insert( ObjectDecl * vtableDecl ) {
67                        std::string const & mangledName = SymTab::Mangler::mangleType( vtableDecl->type );
68                        ObjectDecl *& value = vtable_instances[ mangledName ];
69                        if ( value ) {
70                                if ( vtableDecl->storageClasses.is_extern ) {
71                                        return nullptr;
72                                } else if ( ! value->storageClasses.is_extern ) {
73                                        return value;
74                                }
75                        }
76                        value = vtableDecl;
77                        return nullptr;
78                }
79
80                ObjectDecl * lookup( const Type * vtableType ) {
81                        std::string const & mangledName = SymTab::Mangler::mangleType( vtableType );
82                        const auto it = vtable_instances.find( mangledName );
83                        return ( vtable_instances.end() == it ) ? nullptr : it->second;
84                }
85        };
86
87        class VirtualCastCore {
88                CastExpr * cast_to_type_id( Expression * expr, int level_of_indirection ) {
89                        Type * type = new StructInstType(
90                                Type::Qualifiers( Type::Const ), pvt_decl );
91                        for (int i = 0 ; i < level_of_indirection ; ++i) {
92                                type = new PointerType( noQualifiers, type );
93                        }
94                        return new CastExpr( expr, type );
95                }
96
97        public:
98                VirtualCastCore() :
99                        indexer(), vcast_decl( nullptr ), pvt_decl( nullptr )
100                {}
101
102                void premutate( FunctionDecl * functionDecl );
103                void premutate( StructDecl * structDecl );
104                void premutate( ObjectDecl * objectDecl );
105
106                Expression * postmutate( VirtualCastExpr * castExpr );
107
108                VirtualTableMap indexer;
109        private:
110                FunctionDecl *vcast_decl;
111                StructDecl *pvt_decl;
112        };
113
114        void VirtualCastCore::premutate( FunctionDecl * functionDecl ) {
115                if ( (! vcast_decl) &&
116                     functionDecl->get_name() == "__cfavir_virtual_cast" ) {
117                        vcast_decl = functionDecl;
118                }
119        }
120
121        void VirtualCastCore::premutate( StructDecl * structDecl ) {
122                if ( pvt_decl || ! structDecl->has_body() ) {
123                        return;
124                } else if ( structDecl->get_name() == "__cfavir_type_info" ) {
125                        pvt_decl = structDecl;
126                }
127        }
128
129        void VirtualCastCore::premutate( ObjectDecl * objectDecl ) {
130                if ( is_type_id_object( objectDecl ) ) {
131                        // Multiple definitions should be fine because of linkonce.
132                        indexer.insert( objectDecl );
133                }
134        }
135
136        /// Better error locations for generated casts.
137        CodeLocation castLocation( const VirtualCastExpr * castExpr ) {
138                if ( castExpr->location.isSet() ) {
139                        return castExpr->location;
140                } else if ( castExpr->arg->location.isSet() ) {
141                        return castExpr->arg->location;
142                } else if ( castExpr->result->location.isSet() ) {
143                        return castExpr->result->location;
144                } else {
145                        return CodeLocation();
146                }
147        }
148
149        [[noreturn]] void castError( const VirtualCastExpr * castExpr, std::string const & message ) {
150                SemanticError( castLocation( castExpr ), message );
151        }
152
153        /// Get the base type from a pointer or reference.
154        const Type * getBaseType( const Type * type ) {
155                if ( auto target = dynamic_cast<const PointerType *>( type ) ) {
156                        return target->base;
157                } else if ( auto target = dynamic_cast<const ReferenceType *>( type ) ) {
158                        return target->base;
159                } else {
160                        return nullptr;
161                }
162        }
163
164        /* Attempt to follow the "head" field of the structure to get the...
165         * Returns nullptr on error, otherwise owner must free returned node.
166         */
167        StructInstType * followHeadPointerType(
168                        const StructInstType * oldType,
169                        const std::string& fieldName,
170                        const CodeLocation& errorLocation ) {
171
172                // First section of the function is all about trying to fill this variable in.
173                StructInstType * newType = nullptr;
174                {
175                        const StructDecl * oldDecl = oldType->baseStruct;
176                        assert( oldDecl );
177
178                        // Helper function for throwing semantic errors.
179                        auto throwError = [&fieldName, &errorLocation, &oldDecl](const std::string& message) {
180                                const std::string& context = "While following head pointer of " +
181                                        oldDecl->name + " named '" + fieldName + "': ";
182                                SemanticError( errorLocation, context + message );
183                        };
184
185                        if ( oldDecl->members.empty() ) {
186                                throwError( "Type has no fields." );
187                        }
188                        const Declaration * memberDecl = oldDecl->members.front();
189                        assert( memberDecl );
190                        const ObjectDecl * fieldDecl = dynamic_cast<const ObjectDecl *>( memberDecl );
191                        assert( fieldDecl );
192                        if ( fieldName != fieldDecl->name ) {
193                                throwError( "Head field did not have expected name." );
194                        }
195
196                        const Type * fieldType = fieldDecl->type;
197                        if ( nullptr == fieldType ) {
198                                throwError( "Could not get head field." );
199                        }
200                        const PointerType * ptrType = dynamic_cast<const PointerType *>( fieldType );
201                        if ( nullptr == ptrType ) {
202                                throwError( "First field is not a pointer type." );
203                        }
204                        assert( ptrType->base );
205                        newType = dynamic_cast<StructInstType *>( ptrType->base );
206                        if ( nullptr == newType ) {
207                                throwError( "First field does not point to a structure type." );
208                        }
209                }
210
211                // Now we can look into copying it.
212                newType = newType->clone();
213                if ( ! oldType->parameters.empty() ) {
214                        deleteAll( newType->parameters );
215                        newType->parameters.clear();
216                        cloneAll( oldType->parameters, newType->parameters );
217                }
218                return newType;
219        }
220
221        /// Get the type-id type from a virtual type.
222        StructInstType * getTypeIdType( const Type * type, const CodeLocation& errorLocation ) {
223                const StructInstType * typeInst = dynamic_cast<const StructInstType *>( type );
224                if ( nullptr == typeInst ) {
225                        return nullptr;
226                }
227                StructInstType * tableInst =
228                        followHeadPointerType( typeInst, "virtual_table", errorLocation );
229                if ( nullptr == tableInst ) {
230                        return nullptr;
231                }
232                StructInstType * typeIdInst =
233                        followHeadPointerType( tableInst, "__cfavir_typeid", errorLocation );
234                delete tableInst;
235                return typeIdInst;
236        }
237
238        Expression * VirtualCastCore::postmutate( VirtualCastExpr * castExpr ) {
239                assertf( castExpr->result, "Virtual Cast target not found before expansion." );
240
241                assert( vcast_decl );
242                assert( pvt_decl );
243
244                const Type * base_type = getBaseType( castExpr->result );
245                if ( nullptr == base_type ) {
246                        castError( castExpr, "Virtual cast target must be a pointer or reference type." );
247                }
248                const Type * type_id_type = getTypeIdType( base_type, castLocation( castExpr ) );
249                if ( nullptr == type_id_type ) {
250                        castError( castExpr, "Ill formed virtual cast target type." );
251                }
252                ObjectDecl * type_id = indexer.lookup( type_id_type );
253                delete type_id_type;
254                if ( nullptr == type_id ) {
255                        castError( castExpr, "Virtual cast does not target a virtual type." );
256                }
257
258                Expression * result = new CastExpr(
259                        new ApplicationExpr( VariableExpr::functionPointer( vcast_decl ), {
260                                cast_to_type_id( new AddressExpr( new VariableExpr( type_id ) ), 1 ),
261                                cast_to_type_id( castExpr->get_arg(), 2 ),
262                        } ),
263                        castExpr->get_result()->clone()
264                );
265
266                castExpr->set_arg( nullptr );
267                castExpr->set_result( nullptr );
268                delete castExpr;
269                return result;
270        }
271
272/// Better error locations for generated casts.
273// TODO: Does the improved distribution of code locations make this unneeded?
274CodeLocation castLocation( const ast::VirtualCastExpr * castExpr ) {
275        if ( castExpr->location.isSet() ) {
276                return castExpr->location;
277        } else if ( castExpr->arg->location.isSet() ) {
278                return castExpr->arg->location;
279        } else {
280                return CodeLocation();
281        }
282}
283
284[[noreturn]] void castError( ast::VirtualCastExpr const * castExpr, std::string const & message ) {
285        SemanticError( castLocation( castExpr ), message );
286}
287
288class TypeIdTable final {
289        ScopedMap<std::string, ast::ObjectDecl const *> instances;
290public:
291        void enterScope() { instances.beginScope(); }
292        void leaveScope() { instances.endScope(); }
293
294        // Attempt to insert an instance into the map. If there is a conflict,
295        // returns the previous declaration for error messages.
296        ast::ObjectDecl const * insert( ast::ObjectDecl const * typeIdDecl ) {
297                std::string mangledName = Mangle::mangleType( typeIdDecl->type );
298                ast::ObjectDecl const *& value = instances[ mangledName ];
299                if ( value ) {
300                        if ( typeIdDecl->storage.is_extern ) {
301                                return nullptr;
302                        } else if ( !value->storage.is_extern ) {
303                                return value;
304                        }
305                }
306                value = typeIdDecl;
307                return nullptr;
308        }
309
310        ast::ObjectDecl const * lookup( ast::Type const * typeIdType ) {
311                std::string mangledName = Mangle::mangleType( typeIdType );
312                auto const it = instances.find( mangledName );
313                return ( instances.end() == it ) ? nullptr : it->second;
314        }
315};
316
317struct ExpandCastsCore final {
318        void previsit( ast::FunctionDecl const * decl );
319        void previsit( ast::StructDecl const * decl );
320        void previsit( ast::ObjectDecl const * decl );
321        ast::Expr const * postvisit( ast::VirtualCastExpr const * expr );
322
323        ast::CastExpr const * cast_to_type_id(
324                ast::Expr const * expr, unsigned int level_of_indirection );
325
326        ast::FunctionDecl const * vcast_decl = nullptr;
327        ast::StructDecl const * info_decl = nullptr;
328
329        TypeIdTable symtab;
330};
331
332void ExpandCastsCore::previsit( ast::FunctionDecl const * decl ) {
333        if ( !vcast_decl && "__cfavir_virtual_cast" == decl->name ) {
334                vcast_decl = decl;
335        }
336}
337
338void ExpandCastsCore::previsit( ast::StructDecl const * decl ) {
339        if ( !info_decl && decl->body && "__cfavir_type_info" == decl->name ) {
340                info_decl = decl;
341        }
342}
343
344void ExpandCastsCore::previsit( ast::ObjectDecl const * decl ) {
345        if ( is_type_id_object( decl ) ) {
346                // Multiple definitions should be fine because of linkonce.
347                symtab.insert( decl );
348        }
349}
350
351/// Get the base type from a pointer or reference.
352ast::Type const * getBaseType( ast::ptr<ast::Type> const & type ) {
353        if ( auto target = type.as<ast::PointerType>() ) {
354                return target->base.get();
355        } else if ( auto target = type.as<ast::ReferenceType>() ) {
356                return target->base.get();
357        } else {
358                return nullptr;
359        }
360}
361
362/// Copy newType, but give the copy the params of the oldType.
363ast::StructInstType * polyCopy(
364                ast::StructInstType const * oldType,
365                ast::StructInstType const * newType ) {
366        assert( oldType->params.size() == newType->params.size() );
367        ast::StructInstType * retType = ast::deepCopy( newType );
368        if ( ! oldType->params.empty() ) {
369                retType->params.clear();
370                for ( auto oldParams : oldType->params ) {
371                        retType->params.push_back( ast::deepCopy( oldParams ) );
372                }
373        }
374        return retType;
375}
376
377/// Follow the "head" field of the structure to get the type that is pointed
378/// to by that field.
379ast::StructInstType const * followHeadPointerType(
380                CodeLocation const & errorLocation,
381                ast::StructInstType const * oldType,
382                std::string const & fieldName ) {
383        ast::StructDecl const * oldDecl = oldType->base;
384        assert( oldDecl );
385
386        // Helper function for throwing semantic errors.
387        auto throwError = [&fieldName, &errorLocation, &oldDecl](
388                        std::string const & message ) {
389                std::string const & context = "While following head pointer of " +
390                        oldDecl->name + " named '" + fieldName + "': ";
391                SemanticError( errorLocation, context + message );
392        };
393
394        if ( oldDecl->members.empty() ) {
395                throwError( "Type has no fields." );
396        }
397        ast::ptr<ast::Decl> const & memberDecl = oldDecl->members.front();
398        assert( memberDecl );
399        ast::ObjectDecl const * fieldDecl = memberDecl.as<ast::ObjectDecl>();
400        assert( fieldDecl );
401        if ( fieldName != fieldDecl->name ) {
402                throwError( "Head field did not have expected name." );
403        }
404
405        ast::ptr<ast::Type> const & fieldType = fieldDecl->type;
406        if ( nullptr == fieldType ) {
407                throwError( "Could not get head field." );
408        }
409        auto ptrType = fieldType.as<ast::PointerType>();
410        if ( nullptr == ptrType ) {
411                throwError( "First field is not a pointer type." );
412        }
413        assert( ptrType->base );
414        auto newType = ptrType->base.as<ast::StructInstType>();
415        if ( nullptr == newType ) {
416                throwError( "First field does not point to a structure type." );
417        }
418
419        return polyCopy( oldType, newType );
420}
421
422/// Get the type-id type from a virtual type.
423ast::StructInstType const * getTypeIdType(
424                CodeLocation const & errorLocation,
425                ast::Type const * type ) {
426        auto typeInst = dynamic_cast<ast::StructInstType const *>( type );
427        if ( nullptr == typeInst ) {
428                return nullptr;
429        }
430        ast::ptr<ast::StructInstType> tableInst =
431                followHeadPointerType( errorLocation, typeInst, "virtual_table" );
432        if ( nullptr == tableInst ) {
433                return nullptr;
434        }
435        ast::StructInstType const * typeIdInst =
436                followHeadPointerType( errorLocation, tableInst, "__cfavir_typeid" );
437        return typeIdInst;
438}
439
440ast::Expr const * ExpandCastsCore::postvisit(
441                ast::VirtualCastExpr const * expr ) {
442        assertf( expr->result, "Virtual cast target not found before expansion." );
443
444        assert( vcast_decl );
445        assert( info_decl );
446
447        ast::Type const * base_type = getBaseType( expr->result );
448        if ( nullptr == base_type ) {
449                castError( expr, "Virtual cast target must be a pointer or reference type." );
450        }
451        ast::StructInstType const * type_id_type =
452                        getTypeIdType( castLocation( expr ), base_type );
453        if ( nullptr == type_id_type ) {
454                castError( expr, "Ill formed virtual cast target type." );
455        }
456        ast::ObjectDecl const * type_id = symtab.lookup( type_id_type );
457        if ( nullptr == type_id ) {
458                // I'm trying to give a different error for polymorpic types as
459                // different things can go wrong there.
460                if ( type_id_type->params.empty() ) {
461                        castError( expr, "Virtual cast does not target a virtual type." );
462                } else {
463                        castError( expr, "Virtual cast does not target a type with a "
464                                "type id (possible missing virtual table)." );
465                }
466        }
467
468        return new ast::CastExpr( expr->location,
469                new ast::ApplicationExpr( expr->location,
470                        ast::VariableExpr::functionPointer( expr->location, vcast_decl ),
471                        {
472                                cast_to_type_id(
473                                        new ast::AddressExpr( expr->location,
474                                                new ast::VariableExpr( expr->location, type_id ) ),
475                                        1 ),
476                                cast_to_type_id( expr->arg, 2 ),
477                        }
478                ),
479                ast::deepCopy( expr->result )
480        );
481}
482
483ast::CastExpr const * ExpandCastsCore::cast_to_type_id(
484                ast::Expr const * expr, unsigned int level_of_indirection ) {
485        assert( info_decl );
486        ast::Type * type = new ast::StructInstType( info_decl, ast::CV::Const );
487        for ( unsigned int i = 0 ; i < level_of_indirection ; ++i ) {
488                type = new ast::PointerType( type );
489        }
490        return new ast::CastExpr( expr->location, expr, type );
491}
492
493} // namespace
494
495void expandCasts( std::list< Declaration * > & translationUnit ) {
496        PassVisitor<VirtualCastCore> translator;
497        mutateAll( translationUnit, translator );
498}
499
500void expandCasts( ast::TranslationUnit & translationUnit ) {
501        ast::Pass<ExpandCastsCore>::run( translationUnit );
502}
503
504} // namespace Virtual
Note: See TracBrowser for help on using the repository browser.