source: src/Virtual/ExpandCasts.cc@ 11ab0b4a

Last change on this file since 11ab0b4a was bccd70a, checked in by Andrew Beach <ajbeach@…>, 2 years ago

Removed internal code from TypeSubstitution header. It caused a chain of include problems, which have been corrected.

  • Property mode set to 100644
File size: 15.8 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// ExpandCasts.cc --
8//
9// Author : Andrew Beach
10// Created On : Mon Jul 24 13:59:00 2017
11// Last Modified By : Andrew Beach
12// Last Modified On : Thu Aug 11 12:06:00 2022
13// Update Count : 5
14//
15
16#include "ExpandCasts.h"
17
18#include <cassert> // for assert, assertf
19#include <iterator> // for back_inserter, inserter
20#include <string> // for string, allocator, operator==, ope...
21
22#include "AST/Copy.hpp"
23#include "AST/Decl.hpp"
24#include "AST/Expr.hpp"
25#include "AST/Pass.hpp"
26#include "Common/PassVisitor.h" // for PassVisitor
27#include "Common/ScopedMap.h" // for ScopedMap
28#include "Common/SemanticError.h" // for SemanticError
29#include "SymTab/Mangler.h" // for mangleType
30#include "SynTree/Declaration.h" // for ObjectDecl, StructDecl, FunctionDecl
31#include "SynTree/Expression.h" // for VirtualCastExpr, CastExpr, Address...
32#include "SynTree/Mutator.h" // for mutateAll
33#include "SynTree/Type.h" // for Type, PointerType, StructInstType
34#include "SynTree/Visitor.h" // for acceptAll
35
36namespace Virtual {
37
38namespace {
39
/// Check whether `prefix` is a proper prefix of `entire`: `entire` must
/// begin with `prefix` and contain at least one more character after it.
/// (A string is therefore never a prefix of itself.)
bool is_prefix( const std::string & prefix, const std::string& entire ) {
	size_t const p_size = prefix.size();
	// compare() checks the leading characters in place, avoiding the
	// temporary string that substr() would allocate.
	return p_size < entire.size() && 0 == entire.compare( 0, p_size, prefix );
}
44
45bool is_type_id_object( const ObjectDecl * objectDecl ) {
46 const std::string & objectName = objectDecl->name;
47 return is_prefix( "__cfatid_", objectName );
48}
49
50bool is_type_id_object( const ast::ObjectDecl * decl ) {
51 return is_prefix( "__cfatid_", decl->name );
52}
53
54 // Indented until the new ast code gets added.
55
56 /// Maps virtual table types the instance for that type.
57 class VirtualTableMap final {
58 ScopedMap<std::string, ObjectDecl *> vtable_instances;
59 public:
60 void enterScope() {
61 vtable_instances.beginScope();
62 }
63 void leaveScope() {
64 vtable_instances.endScope();
65 }
66
67 ObjectDecl * insert( ObjectDecl * vtableDecl ) {
68 std::string const & mangledName = SymTab::Mangler::mangleType( vtableDecl->type );
69 ObjectDecl *& value = vtable_instances[ mangledName ];
70 if ( value ) {
71 if ( vtableDecl->storageClasses.is_extern ) {
72 return nullptr;
73 } else if ( ! value->storageClasses.is_extern ) {
74 return value;
75 }
76 }
77 value = vtableDecl;
78 return nullptr;
79 }
80
81 ObjectDecl * lookup( const Type * vtableType ) {
82 std::string const & mangledName = SymTab::Mangler::mangleType( vtableType );
83 const auto it = vtable_instances.find( mangledName );
84 return ( vtable_instances.end() == it ) ? nullptr : it->second;
85 }
86 };
87
88 class VirtualCastCore {
89 CastExpr * cast_to_type_id( Expression * expr, int level_of_indirection ) {
90 Type * type = new StructInstType(
91 Type::Qualifiers( Type::Const ), pvt_decl );
92 for (int i = 0 ; i < level_of_indirection ; ++i) {
93 type = new PointerType( noQualifiers, type );
94 }
95 return new CastExpr( expr, type );
96 }
97
98 public:
99 VirtualCastCore() :
100 indexer(), vcast_decl( nullptr ), pvt_decl( nullptr )
101 {}
102
103 void premutate( FunctionDecl * functionDecl );
104 void premutate( StructDecl * structDecl );
105 void premutate( ObjectDecl * objectDecl );
106
107 Expression * postmutate( VirtualCastExpr * castExpr );
108
109 VirtualTableMap indexer;
110 private:
111 FunctionDecl *vcast_decl;
112 StructDecl *pvt_decl;
113 };
114
115 void VirtualCastCore::premutate( FunctionDecl * functionDecl ) {
116 if ( (! vcast_decl) &&
117 functionDecl->get_name() == "__cfavir_virtual_cast" ) {
118 vcast_decl = functionDecl;
119 }
120 }
121
122 void VirtualCastCore::premutate( StructDecl * structDecl ) {
123 if ( pvt_decl || ! structDecl->has_body() ) {
124 return;
125 } else if ( structDecl->get_name() == "__cfavir_type_info" ) {
126 pvt_decl = structDecl;
127 }
128 }
129
130 void VirtualCastCore::premutate( ObjectDecl * objectDecl ) {
131 if ( is_type_id_object( objectDecl ) ) {
132 // Multiple definitions should be fine because of linkonce.
133 indexer.insert( objectDecl );
134 }
135 }
136
137 /// Better error locations for generated casts.
138 CodeLocation castLocation( const VirtualCastExpr * castExpr ) {
139 if ( castExpr->location.isSet() ) {
140 return castExpr->location;
141 } else if ( castExpr->arg->location.isSet() ) {
142 return castExpr->arg->location;
143 } else if ( castExpr->result->location.isSet() ) {
144 return castExpr->result->location;
145 } else {
146 return CodeLocation();
147 }
148 }
149
150 [[noreturn]] void castError( const VirtualCastExpr * castExpr, std::string const & message ) {
151 SemanticError( castLocation( castExpr ), message );
152 }
153
154 /// Get the base type from a pointer or reference.
155 const Type * getBaseType( const Type * type ) {
156 if ( auto target = dynamic_cast<const PointerType *>( type ) ) {
157 return target->base;
158 } else if ( auto target = dynamic_cast<const ReferenceType *>( type ) ) {
159 return target->base;
160 } else {
161 return nullptr;
162 }
163 }
164
165 /* Attempt to follow the "head" field of the structure to get the...
166 * Returns nullptr on error, otherwise owner must free returned node.
167 */
168 StructInstType * followHeadPointerType(
169 const StructInstType * oldType,
170 const std::string& fieldName,
171 const CodeLocation& errorLocation ) {
172
173 // First section of the function is all about trying to fill this variable in.
174 StructInstType * newType = nullptr;
175 {
176 const StructDecl * oldDecl = oldType->baseStruct;
177 assert( oldDecl );
178
179 // Helper function for throwing semantic errors.
180 auto throwError = [&fieldName, &errorLocation, &oldDecl](const std::string& message) {
181 const std::string& context = "While following head pointer of " +
182 oldDecl->name + " named '" + fieldName + "': ";
183 SemanticError( errorLocation, context + message );
184 };
185
186 if ( oldDecl->members.empty() ) {
187 throwError( "Type has no fields." );
188 }
189 const Declaration * memberDecl = oldDecl->members.front();
190 assert( memberDecl );
191 const ObjectDecl * fieldDecl = dynamic_cast<const ObjectDecl *>( memberDecl );
192 assert( fieldDecl );
193 if ( fieldName != fieldDecl->name ) {
194 throwError( "Head field did not have expected name." );
195 }
196
197 const Type * fieldType = fieldDecl->type;
198 if ( nullptr == fieldType ) {
199 throwError( "Could not get head field." );
200 }
201 const PointerType * ptrType = dynamic_cast<const PointerType *>( fieldType );
202 if ( nullptr == ptrType ) {
203 throwError( "First field is not a pointer type." );
204 }
205 assert( ptrType->base );
206 newType = dynamic_cast<StructInstType *>( ptrType->base );
207 if ( nullptr == newType ) {
208 throwError( "First field does not point to a structure type." );
209 }
210 }
211
212 // Now we can look into copying it.
213 newType = newType->clone();
214 if ( ! oldType->parameters.empty() ) {
215 deleteAll( newType->parameters );
216 newType->parameters.clear();
217 cloneAll( oldType->parameters, newType->parameters );
218 }
219 return newType;
220 }
221
222 /// Get the type-id type from a virtual type.
223 StructInstType * getTypeIdType( const Type * type, const CodeLocation& errorLocation ) {
224 const StructInstType * typeInst = dynamic_cast<const StructInstType *>( type );
225 if ( nullptr == typeInst ) {
226 return nullptr;
227 }
228 StructInstType * tableInst =
229 followHeadPointerType( typeInst, "virtual_table", errorLocation );
230 if ( nullptr == tableInst ) {
231 return nullptr;
232 }
233 StructInstType * typeIdInst =
234 followHeadPointerType( tableInst, "__cfavir_typeid", errorLocation );
235 delete tableInst;
236 return typeIdInst;
237 }
238
239 Expression * VirtualCastCore::postmutate( VirtualCastExpr * castExpr ) {
240 assertf( castExpr->result, "Virtual Cast target not found before expansion." );
241
242 assert( vcast_decl );
243 assert( pvt_decl );
244
245 const Type * base_type = getBaseType( castExpr->result );
246 if ( nullptr == base_type ) {
247 castError( castExpr, "Virtual cast target must be a pointer or reference type." );
248 }
249 const Type * type_id_type = getTypeIdType( base_type, castLocation( castExpr ) );
250 if ( nullptr == type_id_type ) {
251 castError( castExpr, "Ill formed virtual cast target type." );
252 }
253 ObjectDecl * type_id = indexer.lookup( type_id_type );
254 delete type_id_type;
255 if ( nullptr == type_id ) {
256 castError( castExpr, "Virtual cast does not target a virtual type." );
257 }
258
259 Expression * result = new CastExpr(
260 new ApplicationExpr( VariableExpr::functionPointer( vcast_decl ), {
261 cast_to_type_id( new AddressExpr( new VariableExpr( type_id ) ), 1 ),
262 cast_to_type_id( castExpr->get_arg(), 2 ),
263 } ),
264 castExpr->get_result()->clone()
265 );
266
267 castExpr->set_arg( nullptr );
268 castExpr->set_result( nullptr );
269 delete castExpr;
270 return result;
271 }
272
273/// Better error locations for generated casts.
274// TODO: Does the improved distribution of code locations make this unneeded?
275CodeLocation castLocation( const ast::VirtualCastExpr * castExpr ) {
276 if ( castExpr->location.isSet() ) {
277 return castExpr->location;
278 } else if ( castExpr->arg->location.isSet() ) {
279 return castExpr->arg->location;
280 } else {
281 return CodeLocation();
282 }
283}
284
285[[noreturn]] void castError( ast::VirtualCastExpr const * castExpr, std::string const & message ) {
286 SemanticError( castLocation( castExpr ), message );
287}
288
289class TypeIdTable final {
290 ScopedMap<std::string, ast::ObjectDecl const *> instances;
291public:
292 void enterScope() { instances.beginScope(); }
293 void leaveScope() { instances.endScope(); }
294
295 // Attempt to insert an instance into the map. If there is a conflict,
296 // returns the previous declaration for error messages.
297 ast::ObjectDecl const * insert( ast::ObjectDecl const * typeIdDecl ) {
298 std::string mangledName = Mangle::mangleType( typeIdDecl->type );
299 ast::ObjectDecl const *& value = instances[ mangledName ];
300 if ( value ) {
301 if ( typeIdDecl->storage.is_extern ) {
302 return nullptr;
303 } else if ( !value->storage.is_extern ) {
304 return value;
305 }
306 }
307 value = typeIdDecl;
308 return nullptr;
309 }
310
311 ast::ObjectDecl const * lookup( ast::Type const * typeIdType ) {
312 std::string mangledName = Mangle::mangleType( typeIdType );
313 auto const it = instances.find( mangledName );
314 return ( instances.end() == it ) ? nullptr : it->second;
315 }
316};
317
318struct ExpandCastsCore final {
319 void previsit( ast::FunctionDecl const * decl );
320 void previsit( ast::StructDecl const * decl );
321 void previsit( ast::ObjectDecl const * decl );
322 ast::Expr const * postvisit( ast::VirtualCastExpr const * expr );
323
324 ast::CastExpr const * cast_to_type_id(
325 ast::Expr const * expr, unsigned int level_of_indirection );
326
327 ast::FunctionDecl const * vcast_decl = nullptr;
328 ast::StructDecl const * info_decl = nullptr;
329
330 TypeIdTable symtab;
331};
332
333void ExpandCastsCore::previsit( ast::FunctionDecl const * decl ) {
334 if ( !vcast_decl && "__cfavir_virtual_cast" == decl->name ) {
335 vcast_decl = decl;
336 }
337}
338
339void ExpandCastsCore::previsit( ast::StructDecl const * decl ) {
340 if ( !info_decl && decl->body && "__cfavir_type_info" == decl->name ) {
341 info_decl = decl;
342 }
343}
344
345void ExpandCastsCore::previsit( ast::ObjectDecl const * decl ) {
346 if ( is_type_id_object( decl ) ) {
347 // Multiple definitions should be fine because of linkonce.
348 symtab.insert( decl );
349 }
350}
351
352/// Get the base type from a pointer or reference.
353ast::Type const * getBaseType( ast::ptr<ast::Type> const & type ) {
354 if ( auto target = type.as<ast::PointerType>() ) {
355 return target->base.get();
356 } else if ( auto target = type.as<ast::ReferenceType>() ) {
357 return target->base.get();
358 } else {
359 return nullptr;
360 }
361}
362
363/// Copy newType, but give the copy the params of the oldType.
364ast::StructInstType * polyCopy(
365 ast::StructInstType const * oldType,
366 ast::StructInstType const * newType ) {
367 assert( oldType->params.size() == newType->params.size() );
368 ast::StructInstType * retType = ast::deepCopy( newType );
369 if ( ! oldType->params.empty() ) {
370 retType->params.clear();
371 for ( auto oldParams : oldType->params ) {
372 retType->params.push_back( ast::deepCopy( oldParams ) );
373 }
374 }
375 return retType;
376}
377
378/// Follow the "head" field of the structure to get the type that is pointed
379/// to by that field.
380ast::StructInstType const * followHeadPointerType(
381 CodeLocation const & errorLocation,
382 ast::StructInstType const * oldType,
383 std::string const & fieldName ) {
384 ast::StructDecl const * oldDecl = oldType->base;
385 assert( oldDecl );
386
387 // Helper function for throwing semantic errors.
388 auto throwError = [&fieldName, &errorLocation, &oldDecl](
389 std::string const & message ) {
390 std::string const & context = "While following head pointer of " +
391 oldDecl->name + " named '" + fieldName + "': ";
392 SemanticError( errorLocation, context + message );
393 };
394
395 if ( oldDecl->members.empty() ) {
396 throwError( "Type has no fields." );
397 }
398 ast::ptr<ast::Decl> const & memberDecl = oldDecl->members.front();
399 assert( memberDecl );
400 ast::ObjectDecl const * fieldDecl = memberDecl.as<ast::ObjectDecl>();
401 assert( fieldDecl );
402 if ( fieldName != fieldDecl->name ) {
403 throwError( "Head field did not have expected name." );
404 }
405
406 ast::ptr<ast::Type> const & fieldType = fieldDecl->type;
407 if ( nullptr == fieldType ) {
408 throwError( "Could not get head field." );
409 }
410 auto ptrType = fieldType.as<ast::PointerType>();
411 if ( nullptr == ptrType ) {
412 throwError( "First field is not a pointer type." );
413 }
414 assert( ptrType->base );
415 auto newType = ptrType->base.as<ast::StructInstType>();
416 if ( nullptr == newType ) {
417 throwError( "First field does not point to a structure type." );
418 }
419
420 return polyCopy( oldType, newType );
421}
422
423/// Get the type-id type from a virtual type.
424ast::StructInstType const * getTypeIdType(
425 CodeLocation const & errorLocation,
426 ast::Type const * type ) {
427 auto typeInst = dynamic_cast<ast::StructInstType const *>( type );
428 if ( nullptr == typeInst ) {
429 return nullptr;
430 }
431 ast::ptr<ast::StructInstType> tableInst =
432 followHeadPointerType( errorLocation, typeInst, "virtual_table" );
433 if ( nullptr == tableInst ) {
434 return nullptr;
435 }
436 ast::StructInstType const * typeIdInst =
437 followHeadPointerType( errorLocation, tableInst, "__cfavir_typeid" );
438 return typeIdInst;
439}
440
441ast::Expr const * ExpandCastsCore::postvisit(
442 ast::VirtualCastExpr const * expr ) {
443 assertf( expr->result, "Virtual cast target not found before expansion." );
444
445 assert( vcast_decl );
446 assert( info_decl );
447
448 ast::Type const * base_type = getBaseType( expr->result );
449 if ( nullptr == base_type ) {
450 castError( expr, "Virtual cast target must be a pointer or reference type." );
451 }
452 ast::StructInstType const * type_id_type =
453 getTypeIdType( castLocation( expr ), base_type );
454 if ( nullptr == type_id_type ) {
455 castError( expr, "Ill formed virtual cast target type." );
456 }
457 ast::ObjectDecl const * type_id = symtab.lookup( type_id_type );
458 if ( nullptr == type_id ) {
459 // I'm trying to give a different error for polymorpic types as
460 // different things can go wrong there.
461 if ( type_id_type->params.empty() ) {
462 castError( expr, "Virtual cast does not target a virtual type." );
463 } else {
464 castError( expr, "Virtual cast does not target a type with a "
465 "type id (possible missing virtual table)." );
466 }
467 }
468
469 return new ast::CastExpr( expr->location,
470 new ast::ApplicationExpr( expr->location,
471 ast::VariableExpr::functionPointer( expr->location, vcast_decl ),
472 {
473 cast_to_type_id(
474 new ast::AddressExpr( expr->location,
475 new ast::VariableExpr( expr->location, type_id ) ),
476 1 ),
477 cast_to_type_id( expr->arg, 2 ),
478 }
479 ),
480 ast::deepCopy( expr->result )
481 );
482}
483
484ast::CastExpr const * ExpandCastsCore::cast_to_type_id(
485 ast::Expr const * expr, unsigned int level_of_indirection ) {
486 assert( info_decl );
487 ast::Type * type = new ast::StructInstType( info_decl, ast::CV::Const );
488 for ( unsigned int i = 0 ; i < level_of_indirection ; ++i ) {
489 type = new ast::PointerType( type );
490 }
491 return new ast::CastExpr( expr->location, expr, type );
492}
493
494} // namespace
495
496void expandCasts( std::list< Declaration * > & translationUnit ) {
497 PassVisitor<VirtualCastCore> translator;
498 mutateAll( translationUnit, translator );
499}
500
501void expandCasts( ast::TranslationUnit & translationUnit ) {
502 ast::Pass<ExpandCastsCore>::run( translationUnit );
503}
504
505} // namespace Virtual
Note: See TracBrowser for help on using the repository browser.