source: src/GenPoly/GenPoly.cc@ 747d0fa

ADT ast-experimental pthread-emulation
Last change on this file since 747d0fa was 3606fe4, checked in by Andrew Beach <ajbeach@…>, 3 years ago

Translated Instantiate Generic to the new AST. This includes various utilities and some assorted clean-up.

  • Property mode set to 100644
File size: 30.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// GenPoly.cc --
8//
9// Author : Richard C. Bilson
10// Created On : Mon May 18 07:44:20 2015
11// Last Modified By : Andrew Beach
12// Last Modified On : Wed Sep 14 9:24:00 2022
13// Update Count : 15
14//
15
16#include "GenPoly.h"
17
18#include <cassert> // for assertf, assert
19#include <iostream> // for operator<<, ostream, basic_os...
20#include <iterator> // for back_insert_iterator, back_in...
21#include <list> // for list, _List_iterator, list<>:...
22#include <typeindex> // for type_index
23#include <utility> // for pair
24#include <vector> // for vector
25
26#include "AST/Type.hpp"
27#include "GenPoly/ErasableScopedMap.h" // for ErasableScopedMap<>::const_it...
28#include "ResolvExpr/typeops.h" // for flatten
29#include "SynTree/Constant.h" // for Constant
30#include "SynTree/Expression.h" // for Expression, TypeExpr, Constan...
31#include "SynTree/Type.h" // for Type, StructInstType, UnionIn...
32#include "SynTree/TypeSubstitution.h" // for TypeSubstitution
33
34using namespace std;
35
36namespace GenPoly {
37 namespace {
38 /// Checks a parameter list for polymorphic parameters; will substitute according to env if present
39 bool hasPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
40 for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
41 TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
42 assertf(paramType, "Aggregate parameters should be type expressions");
43 if ( isPolyType( paramType->get_type(), env ) ) return true;
44 }
45 return false;
46 }
47
48 bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const ast::TypeSubstitution * env) {
49 for (auto &param : params) {
50 auto paramType = param.strict_as<ast::TypeExpr>();
51 if (isPolyType(paramType->type, env)) return true;
52 }
53 return false;
54 }
55
56 /// Checks a parameter list for polymorphic parameters from tyVars; will substitute according to env if present
57 bool hasPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
58 for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
59 TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
60 assertf(paramType, "Aggregate parameters should be type expressions");
61 if ( isPolyType( paramType->get_type(), tyVars, env ) ) return true;
62 }
63 return false;
64 }
65
66 __attribute__((unused))
67 bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const TyVarMap & tyVars, const ast::TypeSubstitution * env) {
68 for (auto &param : params) {
69 auto paramType = param.strict_as<ast::TypeExpr>();
70 if (isPolyType(paramType->type, tyVars, env)) return true;
71 }
72 return false;
73 }
74
75 /// Checks a parameter list for dynamic-layout parameters from tyVars; will substitute according to env if present
76 bool hasDynParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
77 for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
78 TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
79 assertf(paramType, "Aggregate parameters should be type expressions");
80 if ( isDynType( paramType->get_type(), tyVars, env ) ) return true;
81 }
82 return false;
83 }
84
85 bool hasDynParams( const std::vector<ast::ptr<ast::Expr>> & params, const TyVarMap &tyVars, const ast::TypeSubstitution *typeSubs ) {
86 for ( ast::ptr<ast::Expr> const & param : params ) {
87 auto paramType = param.as<ast::TypeExpr>();
88 assertf( paramType, "Aggregate parameters should be type expressions." );
89 if ( isDynType( paramType->type, tyVars, typeSubs ) ) {
90 return true;
91 }
92 }
93 return false;
94 }
95
96 /// Checks a parameter list for inclusion of polymorphic parameters; will substitute according to env if present
97 bool includesPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
98 for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
99 TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
100 assertf(paramType, "Aggregate parameters should be type expressions");
101 if ( includesPolyType( paramType->get_type(), env ) ) return true;
102 }
103 return false;
104 }
105
106 /// Checks a parameter list for inclusion of polymorphic parameters from tyVars; will substitute according to env if present
107 bool includesPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
108 for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
109 TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
110 assertf(paramType, "Aggregate parameters should be type expressions");
111 if ( includesPolyType( paramType->get_type(), tyVars, env ) ) return true;
112 }
113 return false;
114 }
115 }
116
117 Type* replaceTypeInst( Type* type, const TypeSubstitution* env ) {
118 if ( ! env ) return type;
119 if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( type ) ) {
120 Type *newType = env->lookup( typeInst->get_name() );
121 if ( newType ) return newType;
122 }
123 return type;
124 }
125
126 const ast::Type * replaceTypeInst(const ast::Type * type, const ast::TypeSubstitution * env) {
127 if (!env) return type;
128 if (auto typeInst = dynamic_cast<const ast::TypeInstType*> (type)) {
129 auto newType = env->lookup(typeInst);
130 if (newType) return newType;
131 }
132 return type;
133 }
134
135 Type *isPolyType( Type *type, const TypeSubstitution *env ) {
136 type = replaceTypeInst( type, env );
137
138 if ( dynamic_cast< TypeInstType * >( type ) ) {
139 return type;
140 } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
141 return isPolyType( arrayType->base, env );
142 } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
143 if ( hasPolyParams( structType->get_parameters(), env ) ) return type;
144 } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
145 if ( hasPolyParams( unionType->get_parameters(), env ) ) return type;
146 }
147 return 0;
148 }
149
150 const ast::Type * isPolyType(const ast::Type * type, const ast::TypeSubstitution * env) {
151 type = replaceTypeInst( type, env );
152
153 if ( dynamic_cast< const ast::TypeInstType * >( type ) ) {
154 return type;
155 } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
156 return isPolyType( arrayType->base, env );
157 } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
158 if ( hasPolyParams( structType->params, env ) ) return type;
159 } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
160 if ( hasPolyParams( unionType->params, env ) ) return type;
161 }
162 return 0;
163 }
164
165 Type *isPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
166 type = replaceTypeInst( type, env );
167
168 if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
169 if ( tyVars.find( typeInst->get_name() ) != tyVars.end() ) {
170 return type;
171 }
172 } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
173 return isPolyType( arrayType->base, tyVars, env );
174 } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
175 if ( hasPolyParams( structType->get_parameters(), tyVars, env ) ) return type;
176 } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
177 if ( hasPolyParams( unionType->get_parameters(), tyVars, env ) ) return type;
178 }
179 return 0;
180 }
181
182 const ast::Type * isPolyType(const ast::Type * type, const TyVarMap & tyVars, const ast::TypeSubstitution * env) {
183 type = replaceTypeInst( type, env );
184
185 if ( auto typeInst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
186 return tyVars.find(typeInst->typeString()) != tyVars.end() ? type : nullptr;
187 } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
188 return isPolyType( arrayType->base, env );
189 } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
190 if ( hasPolyParams( structType->params, env ) ) return type;
191 } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
192 if ( hasPolyParams( unionType->params, env ) ) return type;
193 }
194 return nullptr;
195 }
196
197 ReferenceToType *isDynType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
198 type = replaceTypeInst( type, env );
199
200 if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
201 auto var = tyVars.find( typeInst->get_name() );
202 if ( var != tyVars.end() && var->second.isComplete ) {
203 return typeInst;
204 }
205 } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
206 if ( hasDynParams( structType->get_parameters(), tyVars, env ) ) return structType;
207 } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
208 if ( hasDynParams( unionType->get_parameters(), tyVars, env ) ) return unionType;
209 }
210 return 0;
211 }
212
213 const ast::BaseInstType *isDynType( const ast::Type *type, const TyVarMap &tyVars, const ast::TypeSubstitution *typeSubs ) {
214 type = replaceTypeInst( type, typeSubs );
215
216 if ( auto inst = dynamic_cast<ast::TypeInstType const *>( type ) ) {
217 auto var = tyVars.find( inst->name );
218 if ( var != tyVars.end() && var->second.isComplete ) {
219 return inst;
220 }
221 } else if ( auto inst = dynamic_cast<ast::StructInstType const *>( type ) ) {
222 if ( hasDynParams( inst->params, tyVars, typeSubs ) ) {
223 return inst;
224 }
225 } else if ( auto inst = dynamic_cast<ast::UnionInstType const *>( type ) ) {
226 if ( hasDynParams( inst->params, tyVars, typeSubs ) ) {
227 return inst;
228 }
229 }
230 return nullptr;
231 }
232
233 ReferenceToType *isDynRet( FunctionType *function, const TyVarMap &forallTypes ) {
234 if ( function->get_returnVals().empty() ) return 0;
235
236 return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
237 }
238
239 ReferenceToType *isDynRet( FunctionType *function ) {
240 if ( function->get_returnVals().empty() ) return 0;
241
242 TyVarMap forallTypes( TypeDecl::Data{} );
243 makeTyVarMap( function, forallTypes );
244 return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
245 }
246
247 bool needsAdapter( FunctionType *adaptee, const TyVarMap &tyVars ) {
248// if ( ! adaptee->get_returnVals().empty() && isPolyType( adaptee->get_returnVals().front()->get_type(), tyVars ) ) {
249// return true;
250// } // if
251 if ( isDynRet( adaptee, tyVars ) ) return true;
252
253 for ( std::list< DeclarationWithType* >::const_iterator innerArg = adaptee->get_parameters().begin(); innerArg != adaptee->get_parameters().end(); ++innerArg ) {
254// if ( isPolyType( (*innerArg)->get_type(), tyVars ) ) {
255 if ( isDynType( (*innerArg)->get_type(), tyVars ) ) {
256 return true;
257 } // if
258 } // for
259 return false;
260 }
261
262 Type *isPolyPtr( Type *type, const TypeSubstitution *env ) {
263 type = replaceTypeInst( type, env );
264
265 if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
266 return isPolyType( ptr->get_base(), env );
267 }
268 return 0;
269 }
270
271 Type *isPolyPtr( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
272 type = replaceTypeInst( type, env );
273
274 if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
275 return isPolyType( ptr->get_base(), tyVars, env );
276 }
277 return 0;
278 }
279
280 Type * hasPolyBase( Type *type, int *levels, const TypeSubstitution *env ) {
281 int dummy;
282 if ( ! levels ) { levels = &dummy; }
283 *levels = 0;
284
285 while ( true ) {
286 type = replaceTypeInst( type, env );
287
288 if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
289 type = ptr->get_base();
290 ++(*levels);
291 } else break;
292 }
293
294 return isPolyType( type, env );
295 }
296
297 Type * hasPolyBase( Type *type, const TyVarMap &tyVars, int *levels, const TypeSubstitution *env ) {
298 int dummy;
299 if ( ! levels ) { levels = &dummy; }
300 *levels = 0;
301
302 while ( true ) {
303 type = replaceTypeInst( type, env );
304
305 if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
306 type = ptr->get_base();
307 ++(*levels);
308 } else break;
309 }
310
311 return isPolyType( type, tyVars, env );
312 }
313
314 bool includesPolyType( Type *type, const TypeSubstitution *env ) {
315 type = replaceTypeInst( type, env );
316
317 if ( dynamic_cast< TypeInstType * >( type ) ) {
318 return true;
319 } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
320 if ( includesPolyType( pointerType->get_base(), env ) ) return true;
321 } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
322 if ( includesPolyParams( structType->get_parameters(), env ) ) return true;
323 } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
324 if ( includesPolyParams( unionType->get_parameters(), env ) ) return true;
325 }
326 return false;
327 }
328
329 bool includesPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
330 type = replaceTypeInst( type, env );
331
332 if ( TypeInstType *typeInstType = dynamic_cast< TypeInstType * >( type ) ) {
333 if ( tyVars.find( typeInstType->get_name() ) != tyVars.end() ) {
334 return true;
335 }
336 } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
337 if ( includesPolyType( pointerType->get_base(), tyVars, env ) ) return true;
338 } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
339 if ( includesPolyParams( structType->get_parameters(), tyVars, env ) ) return true;
340 } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
341 if ( includesPolyParams( unionType->get_parameters(), tyVars, env ) ) return true;
342 }
343 return false;
344 }
345
346 FunctionType * getFunctionType( Type *ty ) {
347 PointerType *ptrType;
348 if ( ( ptrType = dynamic_cast< PointerType* >( ty ) ) ) {
349 return dynamic_cast< FunctionType* >( ptrType->get_base() ); // pointer if FunctionType, NULL otherwise
350 } else {
351 return dynamic_cast< FunctionType* >( ty ); // pointer if FunctionType, NULL otherwise
352 }
353 }
354
355 const ast::FunctionType * getFunctionType( const ast::Type * ty ) {
356 if ( auto pty = dynamic_cast< const ast::PointerType * >( ty ) ) {
357 return pty->base.as< ast::FunctionType >();
358 } else {
359 return dynamic_cast< const ast::FunctionType * >( ty );
360 }
361 }
362
363 VariableExpr * getBaseVar( Expression *expr, int *levels ) {
364 int dummy;
365 if ( ! levels ) { levels = &dummy; }
366 *levels = 0;
367
368 while ( true ) {
369 if ( VariableExpr *varExpr = dynamic_cast< VariableExpr* >( expr ) ) {
370 return varExpr;
371 } else if ( MemberExpr *memberExpr = dynamic_cast< MemberExpr* >( expr ) ) {
372 expr = memberExpr->get_aggregate();
373 } else if ( AddressExpr *addressExpr = dynamic_cast< AddressExpr* >( expr ) ) {
374 expr = addressExpr->get_arg();
375 } else if ( UntypedExpr *untypedExpr = dynamic_cast< UntypedExpr* >( expr ) ) {
376 // look for compiler-inserted dereference operator
377 NameExpr *fn = dynamic_cast< NameExpr* >( untypedExpr->get_function() );
378 if ( ! fn || fn->get_name() != std::string("*?") ) return 0;
379 expr = *untypedExpr->begin_args();
380 } else if ( CommaExpr *commaExpr = dynamic_cast< CommaExpr* >( expr ) ) {
381 // copy constructors insert comma exprs, look at second argument which contains the variable
382 expr = commaExpr->get_arg2();
383 continue;
384 } else if ( ConditionalExpr * condExpr = dynamic_cast< ConditionalExpr * >( expr ) ) {
385 int lvl1;
386 int lvl2;
387 VariableExpr * var1 = getBaseVar( condExpr->get_arg2(), &lvl1 );
388 VariableExpr * var2 = getBaseVar( condExpr->get_arg3(), &lvl2 );
389 if ( lvl1 == lvl2 && var1 && var2 && var1->get_var() == var2->get_var() ) {
390 *levels = lvl1;
391 return var1;
392 }
393 break;
394 } else break;
395
396 ++(*levels);
397 }
398
399 return 0;
400 }
401
402 namespace {
/// Reports whether the dynamic type of *p is exactly D (derived classes of D do not
/// count); the comparison is made on std::type_index values.
template<typename D, typename B>
bool is( const B* p ) {
	const std::type_index wanted{ typeid(D) };
	const std::type_index actual{ typeid(*p) };
	return wanted == actual;
}
406
/// Converts a base pointer to a pointer to D without any run-time check; the caller
/// must already know that *p is a D (e.g. from a typeid comparison).
/// static_cast is used rather than reinterpret_cast so that the pointer is adjusted
/// correctly whenever the B subobject does not sit at the start of D.
template<typename D, typename B>
inline D* as( B* p ) { return static_cast<D*>(p); }

/// Const-preserving overload of the unchecked downcast above.
template<typename D, typename B>
inline D const * as( B const * p ) {
	return static_cast<D const *>( p );
}
415
416 /// Flattens a declaration list
417 template<typename Output>
418 void flattenList( list< DeclarationWithType* > src, Output out ) {
419 for ( DeclarationWithType* decl : src ) {
420 ResolvExpr::flatten( decl->get_type(), out );
421 }
422 }
423
424 /// Flattens a list of types
425 template<typename Output>
426 void flattenList( list< Type* > src, Output out ) {
427 for ( Type* ty : src ) {
428 ResolvExpr::flatten( ty, out );
429 }
430 }
431
432 void flattenList( vector<ast::ptr<ast::Type>> const & src,
433 vector<ast::ptr<ast::Type>> & out ) {
434 for ( auto const & type : src ) {
435 ResolvExpr::flatten( type, out );
436 }
437 }
438
439 /// Checks if two lists of parameters are equal up to polymorphic substitution.
440 bool paramListsPolyCompatible( const list< Expression* >& aparams, const list< Expression* >& bparams ) {
441 if ( aparams.size() != bparams.size() ) return false;
442
443 for ( list< Expression* >::const_iterator at = aparams.begin(), bt = bparams.begin();
444 at != aparams.end(); ++at, ++bt ) {
445 TypeExpr *aparam = dynamic_cast< TypeExpr* >(*at);
446 assertf(aparam, "Aggregate parameters should be type expressions");
447 TypeExpr *bparam = dynamic_cast< TypeExpr* >(*bt);
448 assertf(bparam, "Aggregate parameters should be type expressions");
449
450 // xxx - might need to let VoidType be a wildcard here too; could have some voids
451 // stuffed in for dtype-statics.
452 // if ( is<VoidType>( aparam->get_type() ) || is<VoidType>( bparam->get_type() ) ) continue;
453 if ( ! typesPolyCompatible( aparam->get_type(), bparam->get_type() ) ) return false;
454 }
455
456 return true;
457 }
458
459 bool paramListsPolyCompatible(
460 std::vector<ast::ptr<ast::Expr>> const & lparams,
461 std::vector<ast::ptr<ast::Expr>> const & rparams ) {
462 if ( lparams.size() != rparams.size() ) {
463 return false;
464 }
465
466 for ( auto lparam = lparams.begin(), rparam = rparams.begin() ;
467 lparam != lparams.end() ; ++lparam, ++rparam ) {
468 ast::TypeExpr const * lexpr = lparam->as<ast::TypeExpr>();
469 assertf( lexpr, "Aggregate parameters should be type expressions" );
470 ast::TypeExpr const * rexpr = rparam->as<ast::TypeExpr>();
471 assertf( rexpr, "Aggregate parameters should be type expressions" );
472
473 // xxx - might need to let VoidType be a wildcard here too; could have some voids
474 // stuffed in for dtype-statics.
475 // if ( is<VoidType>( lexpr->type() ) || is<VoidType>( bparam->get_type() ) ) continue;
476 if ( !typesPolyCompatible( lexpr->type, rexpr->type ) ) {
477 return false;
478 }
479 }
480
481 return true;
482 }
483 }
484
485 bool typesPolyCompatible( Type *a, Type *b ) {
486 type_index aid{ typeid(*a) };
487 // polymorphic types always match
488 if ( aid == type_index{typeid(TypeInstType)} ) return true;
489
490 type_index bid{ typeid(*b) };
491 // polymorphic types always match
492 if ( bid == type_index{typeid(TypeInstType)} ) return true;
493
494 // can't match otherwise if different types
495 if ( aid != bid ) return false;
496
497 // recurse through type structure (conditions borrowed from Unify.cc)
498 if ( aid == type_index{typeid(BasicType)} ) {
499 return as<BasicType>(a)->get_kind() == as<BasicType>(b)->get_kind();
500 } else if ( aid == type_index{typeid(PointerType)} ) {
501 PointerType *ap = as<PointerType>(a), *bp = as<PointerType>(b);
502
503 // void pointers should match any other pointer type
504 return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
505 || typesPolyCompatible( ap->get_base(), bp->get_base() );
506 } else if ( aid == type_index{typeid(ReferenceType)} ) {
507 ReferenceType *ap = as<ReferenceType>(a), *bp = as<ReferenceType>(b);
508 return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
509 || typesPolyCompatible( ap->get_base(), bp->get_base() );
510 } else if ( aid == type_index{typeid(ArrayType)} ) {
511 ArrayType *aa = as<ArrayType>(a), *ba = as<ArrayType>(b);
512
513 if ( aa->get_isVarLen() ) {
514 if ( ! ba->get_isVarLen() ) return false;
515 } else {
516 if ( ba->get_isVarLen() ) return false;
517
518 ConstantExpr *ad = dynamic_cast<ConstantExpr*>( aa->get_dimension() );
519 ConstantExpr *bd = dynamic_cast<ConstantExpr*>( ba->get_dimension() );
520 if ( ad && bd
521 && ad->get_constant()->get_value() != bd->get_constant()->get_value() )
522 return false;
523 }
524
525 return typesPolyCompatible( aa->get_base(), ba->get_base() );
526 } else if ( aid == type_index{typeid(FunctionType)} ) {
527 FunctionType *af = as<FunctionType>(a), *bf = as<FunctionType>(b);
528
529 vector<Type*> aparams, bparams;
530 flattenList( af->get_parameters(), back_inserter( aparams ) );
531 flattenList( bf->get_parameters(), back_inserter( bparams ) );
532 if ( aparams.size() != bparams.size() ) return false;
533
534 vector<Type*> areturns, breturns;
535 flattenList( af->get_returnVals(), back_inserter( areturns ) );
536 flattenList( bf->get_returnVals(), back_inserter( breturns ) );
537 if ( areturns.size() != breturns.size() ) return false;
538
539 for ( unsigned i = 0; i < aparams.size(); ++i ) {
540 if ( ! typesPolyCompatible( aparams[i], bparams[i] ) ) return false;
541 }
542 for ( unsigned i = 0; i < areturns.size(); ++i ) {
543 if ( ! typesPolyCompatible( areturns[i], breturns[i] ) ) return false;
544 }
545 return true;
546 } else if ( aid == type_index{typeid(StructInstType)} ) {
547 StructInstType *aa = as<StructInstType>(a), *ba = as<StructInstType>(b);
548
549 if ( aa->get_name() != ba->get_name() ) return false;
550 return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
551 } else if ( aid == type_index{typeid(UnionInstType)} ) {
552 UnionInstType *aa = as<UnionInstType>(a), *ba = as<UnionInstType>(b);
553
554 if ( aa->get_name() != ba->get_name() ) return false;
555 return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
556 } else if ( aid == type_index{typeid(EnumInstType)} ) {
557 return as<EnumInstType>(a)->get_name() == as<EnumInstType>(b)->get_name();
558 } else if ( aid == type_index{typeid(TraitInstType)} ) {
559 return as<TraitInstType>(a)->get_name() == as<TraitInstType>(b)->get_name();
560 } else if ( aid == type_index{typeid(TupleType)} ) {
561 TupleType *at = as<TupleType>(a), *bt = as<TupleType>(b);
562
563 vector<Type*> atypes, btypes;
564 flattenList( at->get_types(), back_inserter( atypes ) );
565 flattenList( bt->get_types(), back_inserter( btypes ) );
566 if ( atypes.size() != btypes.size() ) return false;
567
568 for ( unsigned i = 0; i < atypes.size(); ++i ) {
569 if ( ! typesPolyCompatible( atypes[i], btypes[i] ) ) return false;
570 }
571 return true;
572 } else return true; // VoidType, VarArgsType, ZeroType & OneType just need the same type
573 }
574
575bool typesPolyCompatible( ast::Type const * lhs, ast::Type const * rhs ) {
576 type_index const lid = typeid(*lhs);
577
578 // Polymorphic types always match:
579 if ( type_index(typeid(ast::TypeInstType)) == lid ) return true;
580
581 type_index const rid = typeid(*rhs);
582 if ( type_index(typeid(ast::TypeInstType)) == rid ) return true;
583
584 // All other types only match if they are the same type:
585 if ( lid != rid ) return false;
586
587 // So remaining types can be examined case by case.
588 // Recurse through type structure (conditions borrowed from Unify.cc).
589
590 if ( type_index(typeid(ast::BasicType)) == lid ) {
591 return as<ast::BasicType>(lhs)->kind == as<ast::BasicType>(rhs)->kind;
592 } else if ( type_index(typeid(ast::PointerType)) == lid ) {
593 ast::PointerType const * l = as<ast::PointerType>(lhs);
594 ast::PointerType const * r = as<ast::PointerType>(rhs);
595
596 // void pointers should match any other pointer type.
597 return is<ast::VoidType>( l->base.get() )
598 || is<ast::VoidType>( r->base.get() )
599 || typesPolyCompatible( l->base.get(), r->base.get() );
600 } else if ( type_index(typeid(ast::ReferenceType)) == lid ) {
601 ast::ReferenceType const * l = as<ast::ReferenceType>(lhs);
602 ast::ReferenceType const * r = as<ast::ReferenceType>(rhs);
603
604 // void references should match any other reference type.
605 return is<ast::VoidType>( l->base.get() )
606 || is<ast::VoidType>( r->base.get() )
607 || typesPolyCompatible( l->base.get(), r->base.get() );
608 } else if ( type_index(typeid(ast::ArrayType)) == lid ) {
609 ast::ArrayType const * l = as<ast::ArrayType>(lhs);
610 ast::ArrayType const * r = as<ast::ArrayType>(rhs);
611
612 if ( l->isVarLen ) {
613 if ( !r->isVarLen ) return false;
614 } else {
615 if ( r->isVarLen ) return false;
616
617 auto lc = l->dimension.as<ast::ConstantExpr>();
618 auto rc = r->dimension.as<ast::ConstantExpr>();
619 if ( lc && rc && lc->intValue() != rc->intValue() ) {
620 return false;
621 }
622 }
623
624 return typesPolyCompatible( l->base.get(), r->base.get() );
625 } else if ( type_index(typeid(ast::FunctionType)) == lid ) {
626 ast::FunctionType const * l = as<ast::FunctionType>(lhs);
627 ast::FunctionType const * r = as<ast::FunctionType>(rhs);
628
629 std::vector<ast::ptr<ast::Type>> lparams, rparams;
630 flattenList( l->params, lparams );
631 flattenList( r->params, rparams );
632 if ( lparams.size() != rparams.size() ) return false;
633 for ( unsigned i = 0; i < lparams.size(); ++i ) {
634 if ( !typesPolyCompatible( lparams[i], rparams[i] ) ) return false;
635 }
636
637 std::vector<ast::ptr<ast::Type>> lrets, rrets;
638 flattenList( l->returns, lrets );
639 flattenList( r->returns, rrets );
640 if ( lrets.size() != rrets.size() ) return false;
641 for ( unsigned i = 0; i < lrets.size(); ++i ) {
642 if ( !typesPolyCompatible( lrets[i], rrets[i] ) ) return false;
643 }
644 return true;
645 } else if ( type_index(typeid(ast::StructInstType)) == lid ) {
646 ast::StructInstType const * l = as<ast::StructInstType>(lhs);
647 ast::StructInstType const * r = as<ast::StructInstType>(rhs);
648
649 if ( l->name != r->name ) return false;
650 return paramListsPolyCompatible( l->params, r->params );
651 } else if ( type_index(typeid(ast::UnionInstType)) == lid ) {
652 ast::UnionInstType const * l = as<ast::UnionInstType>(lhs);
653 ast::UnionInstType const * r = as<ast::UnionInstType>(rhs);
654
655 if ( l->name != r->name ) return false;
656 return paramListsPolyCompatible( l->params, r->params );
657 } else if ( type_index(typeid(ast::EnumInstType)) == lid ) {
658 ast::EnumInstType const * l = as<ast::EnumInstType>(lhs);
659 ast::EnumInstType const * r = as<ast::EnumInstType>(rhs);
660
661 return l->name == r->name;
662 } else if ( type_index(typeid(ast::TraitInstType)) == lid ) {
663 ast::TraitInstType const * l = as<ast::TraitInstType>(lhs);
664 ast::TraitInstType const * r = as<ast::TraitInstType>(rhs);
665
666 return l->name == r->name;
667 } else if ( type_index(typeid(ast::TupleType)) == lid ) {
668 ast::TupleType const * l = as<ast::TupleType>(lhs);
669 ast::TupleType const * r = as<ast::TupleType>(rhs);
670
671 std::vector<ast::ptr<ast::Type>> ltypes, rtypes;
672 flattenList( l->types, ( ltypes ) );
673 flattenList( r->types, ( rtypes ) );
674 if ( ltypes.size() != rtypes.size() ) return false;
675
676 for ( unsigned i = 0 ; i < ltypes.size() ; ++i ) {
677 if ( !typesPolyCompatible( ltypes[i], rtypes[i] ) ) return false;
678 }
679 return true;
680 // The remaining types (VoidType, VarArgsType, ZeroType & OneType)
681 // have no variation so will always be equal.
682 } else {
683 return true;
684 }
685}
686
687 namespace {
688 // temporary hack to avoid re-implementing anything related to TyVarMap
689 // does this work? these two structs have identical definitions.
690 inline TypeDecl::Data convData(const ast::TypeDecl::Data & data) {
691 return *reinterpret_cast<const TypeDecl::Data *>(&data);
692 }
693 }
694
695 bool needsBoxing( Type * param, Type * arg, const TyVarMap &exprTyVars, const TypeSubstitution * env ) {
696 // is parameter is not polymorphic, don't need to box
697 if ( ! isPolyType( param, exprTyVars ) ) return false;
698 Type * newType = arg->clone();
699 if ( env ) env->apply( newType );
700 std::unique_ptr<Type> manager( newType );
701 // if the argument's type is polymorphic, we don't need to box again!
702 return ! isPolyType( newType );
703 }
704
705 bool needsBoxing( const ast::Type * param, const ast::Type * arg, const TyVarMap &exprTyVars, const ast::TypeSubstitution * env) {
706 // is parameter is not polymorphic, don't need to box
707 if ( ! isPolyType( param, exprTyVars ) ) return false;
708 ast::ptr<ast::Type> newType = arg;
709 if ( env ) env->apply( newType );
710 // if the argument's type is polymorphic, we don't need to box again!
711 return ! isPolyType( newType );
712 }
713
714 bool needsBoxing( Type * param, Type * arg, ApplicationExpr * appExpr, const TypeSubstitution * env ) {
715 FunctionType * function = getFunctionType( appExpr->function->result );
716 assertf( function, "ApplicationExpr has non-function type: %s", toString( appExpr->function->result ).c_str() );
717 TyVarMap exprTyVars( TypeDecl::Data{} );
718 makeTyVarMap( function, exprTyVars );
719 return needsBoxing( param, arg, exprTyVars, env );
720 }
721
722 bool needsBoxing( const ast::Type * param, const ast::Type * arg, const ast::ApplicationExpr * appExpr, const ast::TypeSubstitution * env) {
723 const ast::FunctionType * function = getFunctionType(appExpr->func->result);
724 assertf( function, "ApplicationExpr has non-function type: %s", toString( appExpr->func->result ).c_str() );
725 TyVarMap exprTyVars(TypeDecl::Data{});
726 makeTyVarMap(function, exprTyVars);
727 return needsBoxing(param, arg, exprTyVars, env);
728
729 }
730
731 void addToTyVarMap( TypeDecl * tyVar, TyVarMap &tyVarMap ) {
732 tyVarMap.insert( tyVar->name, TypeDecl::Data{ tyVar } );
733 }
734
735 void addToTyVarMap( const ast::TypeInstType * tyVar, TyVarMap & tyVarMap) {
736 tyVarMap.insert(tyVar->typeString(), convData(ast::TypeDecl::Data{tyVar->base}));
737 }
738
739 void makeTyVarMap( Type *type, TyVarMap &tyVarMap ) {
740 for ( Type::ForallList::const_iterator tyVar = type->get_forall().begin(); tyVar != type->get_forall().end(); ++tyVar ) {
741 assert( *tyVar );
742 addToTyVarMap( *tyVar, tyVarMap );
743 }
744 if ( PointerType *pointer = dynamic_cast< PointerType* >( type ) ) {
745 makeTyVarMap( pointer->get_base(), tyVarMap );
746 }
747 }
748
749 void makeTyVarMap(const ast::Type * type, TyVarMap & tyVarMap) {
750 if (auto ptype = dynamic_cast<const ast::FunctionType *>(type)) {
751 for (auto & tyVar : ptype->forall) {
752 assert (tyVar);
753 addToTyVarMap(tyVar, tyVarMap);
754 }
755 }
756 if (auto pointer = dynamic_cast<const ast::PointerType *>(type)) {
757 makeTyVarMap(pointer->base, tyVarMap);
758 }
759 }
760
761 void printTyVarMap( std::ostream &os, const TyVarMap &tyVarMap ) {
762 for ( TyVarMap::const_iterator i = tyVarMap.begin(); i != tyVarMap.end(); ++i ) {
763 os << i->first << " (" << i->second << ") ";
764 } // for
765 os << std::endl;
766 }
767
768} // namespace GenPoly
769
770// Local Variables: //
771// tab-width: 4 //
772// mode: c++ //
773// compile-command: "make install" //
774// End: //
Note: See TracBrowser for help on using the repository browser.