source: src/GenPoly/GenPoly.cc @ 8c91088

ADTast-experimental
Last change on this file since 8c91088 was 3606fe4, checked in by Andrew Beach <ajbeach@…>, 20 months ago

Translated Instantiate Generic to the new AST. This includes various utilities and some assorted clean-up.

  • Property mode set to 100644
File size: 30.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// GenPoly.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Andrew Beach
12// Last Modified On : Wed Sep 14  9:24:00 2022
13// Update Count     : 15
14//
15
16#include "GenPoly.h"
17
18#include <cassert>                      // for assertf, assert
19#include <iostream>                     // for operator<<, ostream, basic_os...
20#include <iterator>                     // for back_insert_iterator, back_in...
21#include <list>                         // for list, _List_iterator, list<>:...
22#include <typeindex>                    // for type_index
23#include <utility>                      // for pair
24#include <vector>                       // for vector
25
26#include "AST/Type.hpp"
27#include "GenPoly/ErasableScopedMap.h"  // for ErasableScopedMap<>::const_it...
28#include "ResolvExpr/typeops.h"         // for flatten
29#include "SynTree/Constant.h"           // for Constant
30#include "SynTree/Expression.h"         // for Expression, TypeExpr, Constan...
31#include "SynTree/Type.h"               // for Type, StructInstType, UnionIn...
32#include "SynTree/TypeSubstitution.h"   // for TypeSubstitution
33
34using namespace std;
35
36namespace GenPoly {
37        namespace {
38                /// Checks a parameter list for polymorphic parameters; will substitute according to env if present
39                bool hasPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
40                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
41                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
42                                assertf(paramType, "Aggregate parameters should be type expressions");
43                                if ( isPolyType( paramType->get_type(), env ) ) return true;
44                        }
45                        return false;
46                }
47
48                bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const ast::TypeSubstitution * env) {
49                        for (auto &param : params) {
50                                auto paramType = param.strict_as<ast::TypeExpr>();
51                                if (isPolyType(paramType->type, env)) return true;
52                        }
53                        return false;
54                }
55
56                /// Checks a parameter list for polymorphic parameters from tyVars; will substitute according to env if present
57                bool hasPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
58                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
59                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
60                                assertf(paramType, "Aggregate parameters should be type expressions");
61                                if ( isPolyType( paramType->get_type(), tyVars, env ) ) return true;
62                        }
63                        return false;
64                }
65
66                __attribute__((unused))
67                bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const TyVarMap & tyVars, const ast::TypeSubstitution * env) {
68                        for (auto &param : params) {
69                                auto paramType = param.strict_as<ast::TypeExpr>();
70                                if (isPolyType(paramType->type, tyVars, env)) return true;
71                        }
72                        return false;
73                }
74
75                /// Checks a parameter list for dynamic-layout parameters from tyVars; will substitute according to env if present
76                bool hasDynParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
77                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
78                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
79                                assertf(paramType, "Aggregate parameters should be type expressions");
80                                if ( isDynType( paramType->get_type(), tyVars, env ) ) return true;
81                        }
82                        return false;
83                }
84
85                bool hasDynParams( const std::vector<ast::ptr<ast::Expr>> & params, const TyVarMap &tyVars, const ast::TypeSubstitution *typeSubs ) {
86                        for ( ast::ptr<ast::Expr> const & param : params ) {
87                                auto paramType = param.as<ast::TypeExpr>();
88                                assertf( paramType, "Aggregate parameters should be type expressions." );
89                                if ( isDynType( paramType->type, tyVars, typeSubs ) ) {
90                                        return true;
91                                }
92                        }
93                        return false;
94                }
95
96                /// Checks a parameter list for inclusion of polymorphic parameters; will substitute according to env if present
97                bool includesPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
98                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
99                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
100                                assertf(paramType, "Aggregate parameters should be type expressions");
101                                if ( includesPolyType( paramType->get_type(), env ) ) return true;
102                        }
103                        return false;
104                }
105
106                /// Checks a parameter list for inclusion of polymorphic parameters from tyVars; will substitute according to env if present
107                bool includesPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
108                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
109                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
110                                assertf(paramType, "Aggregate parameters should be type expressions");
111                                if ( includesPolyType( paramType->get_type(), tyVars, env ) ) return true;
112                        }
113                        return false;
114                }
115        }
116
117        Type* replaceTypeInst( Type* type, const TypeSubstitution* env ) {
118                if ( ! env ) return type;
119                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( type ) ) {
120                        Type *newType = env->lookup( typeInst->get_name() );
121                        if ( newType ) return newType;
122                }
123                return type;
124        }
125
126        const ast::Type * replaceTypeInst(const ast::Type * type, const ast::TypeSubstitution * env) {
127                if (!env) return type;
128                if (auto typeInst = dynamic_cast<const ast::TypeInstType*> (type)) {
129                        auto newType = env->lookup(typeInst);
130                        if (newType) return newType;
131                }
132                return type;
133        }
134
135        Type *isPolyType( Type *type, const TypeSubstitution *env ) {
136                type = replaceTypeInst( type, env );
137
138                if ( dynamic_cast< TypeInstType * >( type ) ) {
139                        return type;
140                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
141                        return isPolyType( arrayType->base, env );
142                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
143                        if ( hasPolyParams( structType->get_parameters(), env ) ) return type;
144                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
145                        if ( hasPolyParams( unionType->get_parameters(), env ) ) return type;
146                }
147                return 0;
148        }
149
150        const ast::Type * isPolyType(const ast::Type * type, const ast::TypeSubstitution * env) {
151                type = replaceTypeInst( type, env );
152
153                if ( dynamic_cast< const ast::TypeInstType * >( type ) ) {
154                        return type;
155                } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
156                        return isPolyType( arrayType->base, env );
157                } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
158                        if ( hasPolyParams( structType->params, env ) ) return type;
159                } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
160                        if ( hasPolyParams( unionType->params, env ) ) return type;
161                }
162                return 0;
163        }
164
165        Type *isPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
166                type = replaceTypeInst( type, env );
167
168                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
169                        if ( tyVars.find( typeInst->get_name() ) != tyVars.end() ) {
170                                return type;
171                        }
172                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
173                        return isPolyType( arrayType->base, tyVars, env );
174                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
175                        if ( hasPolyParams( structType->get_parameters(), tyVars, env ) ) return type;
176                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
177                        if ( hasPolyParams( unionType->get_parameters(), tyVars, env ) ) return type;
178                }
179                return 0;
180        }
181
182        const ast::Type * isPolyType(const ast::Type * type, const TyVarMap & tyVars, const ast::TypeSubstitution * env) {
183                type = replaceTypeInst( type, env );
184
185                if ( auto typeInst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
186                        return tyVars.find(typeInst->typeString()) != tyVars.end() ? type : nullptr;
187                } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
188                        return isPolyType( arrayType->base, env );
189                } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
190                        if ( hasPolyParams( structType->params, env ) ) return type;
191                } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
192                        if ( hasPolyParams( unionType->params, env ) ) return type;
193                }
194                return nullptr;
195        }
196
197        ReferenceToType *isDynType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
198                type = replaceTypeInst( type, env );
199
200                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
201                        auto var = tyVars.find( typeInst->get_name() );
202                        if ( var != tyVars.end() && var->second.isComplete ) {
203                                return typeInst;
204                        }
205                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
206                        if ( hasDynParams( structType->get_parameters(), tyVars, env ) ) return structType;
207                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
208                        if ( hasDynParams( unionType->get_parameters(), tyVars, env ) ) return unionType;
209                }
210                return 0;
211        }
212
213        const ast::BaseInstType *isDynType( const ast::Type *type, const TyVarMap &tyVars, const ast::TypeSubstitution *typeSubs ) {
214                type = replaceTypeInst( type, typeSubs );
215
216                if ( auto inst = dynamic_cast<ast::TypeInstType const *>( type ) ) {
217                        auto var = tyVars.find( inst->name );
218                        if ( var != tyVars.end() && var->second.isComplete ) {
219                                return inst;
220                        }
221                } else if ( auto inst = dynamic_cast<ast::StructInstType const *>( type ) ) {
222                        if ( hasDynParams( inst->params, tyVars, typeSubs ) ) {
223                                return inst;
224                        }
225                } else if ( auto inst = dynamic_cast<ast::UnionInstType const *>( type ) ) {
226                        if ( hasDynParams( inst->params, tyVars, typeSubs ) ) {
227                                return inst;
228                        }
229                }
230                return nullptr;
231        }
232
233        ReferenceToType *isDynRet( FunctionType *function, const TyVarMap &forallTypes ) {
234                if ( function->get_returnVals().empty() ) return 0;
235
236                return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
237        }
238
239        ReferenceToType *isDynRet( FunctionType *function ) {
240                if ( function->get_returnVals().empty() ) return 0;
241
242                TyVarMap forallTypes( TypeDecl::Data{} );
243                makeTyVarMap( function, forallTypes );
244                return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
245        }
246
247        bool needsAdapter( FunctionType *adaptee, const TyVarMap &tyVars ) {
248//              if ( ! adaptee->get_returnVals().empty() && isPolyType( adaptee->get_returnVals().front()->get_type(), tyVars ) ) {
249//                      return true;
250//              } // if
251                if ( isDynRet( adaptee, tyVars ) ) return true;
252
253                for ( std::list< DeclarationWithType* >::const_iterator innerArg = adaptee->get_parameters().begin(); innerArg != adaptee->get_parameters().end(); ++innerArg ) {
254//                      if ( isPolyType( (*innerArg)->get_type(), tyVars ) ) {
255                        if ( isDynType( (*innerArg)->get_type(), tyVars ) ) {
256                                return true;
257                        } // if
258                } // for
259                return false;
260        }
261
262        Type *isPolyPtr( Type *type, const TypeSubstitution *env ) {
263                type = replaceTypeInst( type, env );
264
265                if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
266                        return isPolyType( ptr->get_base(), env );
267                }
268                return 0;
269        }
270
271        Type *isPolyPtr( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
272                type = replaceTypeInst( type, env );
273
274                if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
275                        return isPolyType( ptr->get_base(), tyVars, env );
276                }
277                return 0;
278        }
279
280        Type * hasPolyBase( Type *type, int *levels, const TypeSubstitution *env ) {
281                int dummy;
282                if ( ! levels ) { levels = &dummy; }
283                *levels = 0;
284
285                while ( true ) {
286                        type = replaceTypeInst( type, env );
287
288                        if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
289                                type = ptr->get_base();
290                                ++(*levels);
291                        } else break;
292                }
293
294                return isPolyType( type, env );
295        }
296
297        Type * hasPolyBase( Type *type, const TyVarMap &tyVars, int *levels, const TypeSubstitution *env ) {
298                int dummy;
299                if ( ! levels ) { levels = &dummy; }
300                *levels = 0;
301
302                while ( true ) {
303                        type = replaceTypeInst( type, env );
304
305                        if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
306                                type = ptr->get_base();
307                                ++(*levels);
308                        } else break;
309                }
310
311                return isPolyType( type, tyVars, env );
312        }
313
314        bool includesPolyType( Type *type, const TypeSubstitution *env ) {
315                type = replaceTypeInst( type, env );
316
317                if ( dynamic_cast< TypeInstType * >( type ) ) {
318                        return true;
319                } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
320                        if ( includesPolyType( pointerType->get_base(), env ) ) return true;
321                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
322                        if ( includesPolyParams( structType->get_parameters(), env ) ) return true;
323                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
324                        if ( includesPolyParams( unionType->get_parameters(), env ) ) return true;
325                }
326                return false;
327        }
328
329        bool includesPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
330                type = replaceTypeInst( type, env );
331
332                if ( TypeInstType *typeInstType = dynamic_cast< TypeInstType * >( type ) ) {
333                        if ( tyVars.find( typeInstType->get_name() ) != tyVars.end() ) {
334                                return true;
335                        }
336                } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
337                        if ( includesPolyType( pointerType->get_base(), tyVars, env ) ) return true;
338                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
339                        if ( includesPolyParams( structType->get_parameters(), tyVars, env ) ) return true;
340                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
341                        if ( includesPolyParams( unionType->get_parameters(), tyVars, env ) ) return true;
342                }
343                return false;
344        }
345
346        FunctionType * getFunctionType( Type *ty ) {
347                PointerType *ptrType;
348                if ( ( ptrType = dynamic_cast< PointerType* >( ty ) ) ) {
349                        return dynamic_cast< FunctionType* >( ptrType->get_base() ); // pointer if FunctionType, NULL otherwise
350                } else {
351                        return dynamic_cast< FunctionType* >( ty ); // pointer if FunctionType, NULL otherwise
352                }
353        }
354
355        const ast::FunctionType * getFunctionType( const ast::Type * ty ) {
356                if ( auto pty = dynamic_cast< const ast::PointerType * >( ty ) ) {
357                        return pty->base.as< ast::FunctionType >();
358                } else {
359                        return dynamic_cast< const ast::FunctionType * >( ty );
360                }
361        }
362
363        VariableExpr * getBaseVar( Expression *expr, int *levels ) {
364                int dummy;
365                if ( ! levels ) { levels = &dummy; }
366                *levels = 0;
367
368                while ( true ) {
369                        if ( VariableExpr *varExpr = dynamic_cast< VariableExpr* >( expr ) ) {
370                                return varExpr;
371                        } else if ( MemberExpr *memberExpr = dynamic_cast< MemberExpr* >( expr ) ) {
372                                expr = memberExpr->get_aggregate();
373                        } else if ( AddressExpr *addressExpr = dynamic_cast< AddressExpr* >( expr ) ) {
374                                expr = addressExpr->get_arg();
375                        } else if ( UntypedExpr *untypedExpr = dynamic_cast< UntypedExpr* >( expr ) ) {
376                                // look for compiler-inserted dereference operator
377                                NameExpr *fn = dynamic_cast< NameExpr* >( untypedExpr->get_function() );
378                                if ( ! fn || fn->get_name() != std::string("*?") ) return 0;
379                                expr = *untypedExpr->begin_args();
380                        } else if ( CommaExpr *commaExpr = dynamic_cast< CommaExpr* >( expr ) ) {
381                                // copy constructors insert comma exprs, look at second argument which contains the variable
382                                expr = commaExpr->get_arg2();
383                                continue;
384                        } else if ( ConditionalExpr * condExpr = dynamic_cast< ConditionalExpr * >( expr ) ) {
385                                int lvl1;
386                                int lvl2;
387                                VariableExpr * var1 = getBaseVar( condExpr->get_arg2(), &lvl1 );
388                                VariableExpr * var2 = getBaseVar( condExpr->get_arg3(), &lvl2 );
389                                if ( lvl1 == lvl2 && var1 && var2 && var1->get_var() == var2->get_var() ) {
390                                        *levels = lvl1;
391                                        return var1;
392                                }
393                                break;
394                        } else break;
395
396                        ++(*levels);
397                }
398
399                return 0;
400        }
401
402        namespace {
403                /// Checks if is a pointer to D
404                template<typename D, typename B>
405                bool is( const B* p ) { return type_index{typeid(D)} == type_index{typeid(*p)}; }
406
407                /// Converts to a pointer to D without checking for safety
408                template<typename D, typename B>
409                inline D* as( B* p ) { return reinterpret_cast<D*>(p); }
410
411                template<typename D, typename B>
412                inline D const * as( B const * p ) {
413                        return reinterpret_cast<D const *>( p );
414                }
415
416                /// Flattens a declaration list
417                template<typename Output>
418                void flattenList( list< DeclarationWithType* > src, Output out ) {
419                        for ( DeclarationWithType* decl : src ) {
420                                ResolvExpr::flatten( decl->get_type(), out );
421                        }
422                }
423
424                /// Flattens a list of types
425                template<typename Output>
426                void flattenList( list< Type* > src, Output out ) {
427                        for ( Type* ty : src ) {
428                                ResolvExpr::flatten( ty, out );
429                        }
430                }
431
432                void flattenList( vector<ast::ptr<ast::Type>> const & src,
433                                vector<ast::ptr<ast::Type>> & out ) {
434                        for ( auto const & type : src ) {
435                                ResolvExpr::flatten( type, out );
436                        }
437                }
438
439                /// Checks if two lists of parameters are equal up to polymorphic substitution.
440                bool paramListsPolyCompatible( const list< Expression* >& aparams, const list< Expression* >& bparams ) {
441                        if ( aparams.size() != bparams.size() ) return false;
442
443                        for ( list< Expression* >::const_iterator at = aparams.begin(), bt = bparams.begin();
444                                        at != aparams.end(); ++at, ++bt ) {
445                                TypeExpr *aparam = dynamic_cast< TypeExpr* >(*at);
446                                assertf(aparam, "Aggregate parameters should be type expressions");
447                                TypeExpr *bparam = dynamic_cast< TypeExpr* >(*bt);
448                                assertf(bparam, "Aggregate parameters should be type expressions");
449
450                                // xxx - might need to let VoidType be a wildcard here too; could have some voids
451                                // stuffed in for dtype-statics.
452                                // if ( is<VoidType>( aparam->get_type() ) || is<VoidType>( bparam->get_type() ) ) continue;
453                                if ( ! typesPolyCompatible( aparam->get_type(), bparam->get_type() ) ) return false;
454                        }
455
456                        return true;
457                }
458
459                bool paramListsPolyCompatible(
460                                std::vector<ast::ptr<ast::Expr>> const & lparams,
461                                std::vector<ast::ptr<ast::Expr>> const & rparams ) {
462                        if ( lparams.size() != rparams.size() ) {
463                                return false;
464                        }
465
466                        for ( auto lparam = lparams.begin(), rparam = rparams.begin() ;
467                                        lparam != lparams.end() ; ++lparam, ++rparam ) {
468                                ast::TypeExpr const * lexpr = lparam->as<ast::TypeExpr>();
469                                assertf( lexpr, "Aggregate parameters should be type expressions" );
470                                ast::TypeExpr const * rexpr = rparam->as<ast::TypeExpr>();
471                                assertf( rexpr, "Aggregate parameters should be type expressions" );
472
473                                // xxx - might need to let VoidType be a wildcard here too; could have some voids
474                                // stuffed in for dtype-statics.
475                                // if ( is<VoidType>( lexpr->type() ) || is<VoidType>( bparam->get_type() ) ) continue;
476                                if ( !typesPolyCompatible( lexpr->type, rexpr->type ) ) {
477                                        return false;
478                                }
479                        }
480
481                        return true;
482                }
483        }
484
485        bool typesPolyCompatible( Type *a, Type *b ) {
486                type_index aid{ typeid(*a) };
487                // polymorphic types always match
488                if ( aid == type_index{typeid(TypeInstType)} ) return true;
489
490                type_index bid{ typeid(*b) };
491                // polymorphic types always match
492                if ( bid == type_index{typeid(TypeInstType)} ) return true;
493
494                // can't match otherwise if different types
495                if ( aid != bid ) return false;
496
497                // recurse through type structure (conditions borrowed from Unify.cc)
498                if ( aid == type_index{typeid(BasicType)} ) {
499                        return as<BasicType>(a)->get_kind() == as<BasicType>(b)->get_kind();
500                } else if ( aid == type_index{typeid(PointerType)} ) {
501                        PointerType *ap = as<PointerType>(a), *bp = as<PointerType>(b);
502
503                        // void pointers should match any other pointer type
504                        return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
505                                || typesPolyCompatible( ap->get_base(), bp->get_base() );
506                } else if ( aid == type_index{typeid(ReferenceType)} ) {
507                        ReferenceType *ap = as<ReferenceType>(a), *bp = as<ReferenceType>(b);
508                        return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
509                                || typesPolyCompatible( ap->get_base(), bp->get_base() );
510                } else if ( aid == type_index{typeid(ArrayType)} ) {
511                        ArrayType *aa = as<ArrayType>(a), *ba = as<ArrayType>(b);
512
513                        if ( aa->get_isVarLen() ) {
514                                if ( ! ba->get_isVarLen() ) return false;
515                        } else {
516                                if ( ba->get_isVarLen() ) return false;
517
518                                ConstantExpr *ad = dynamic_cast<ConstantExpr*>( aa->get_dimension() );
519                                ConstantExpr *bd = dynamic_cast<ConstantExpr*>( ba->get_dimension() );
520                                if ( ad && bd
521                                                && ad->get_constant()->get_value() != bd->get_constant()->get_value() )
522                                        return false;
523                        }
524
525                        return typesPolyCompatible( aa->get_base(), ba->get_base() );
526                } else if ( aid == type_index{typeid(FunctionType)} ) {
527                        FunctionType *af = as<FunctionType>(a), *bf = as<FunctionType>(b);
528
529                        vector<Type*> aparams, bparams;
530                        flattenList( af->get_parameters(), back_inserter( aparams ) );
531                        flattenList( bf->get_parameters(), back_inserter( bparams ) );
532                        if ( aparams.size() != bparams.size() ) return false;
533
534                        vector<Type*> areturns, breturns;
535                        flattenList( af->get_returnVals(), back_inserter( areturns ) );
536                        flattenList( bf->get_returnVals(), back_inserter( breturns ) );
537                        if ( areturns.size() != breturns.size() ) return false;
538
539                        for ( unsigned i = 0; i < aparams.size(); ++i ) {
540                                if ( ! typesPolyCompatible( aparams[i], bparams[i] ) ) return false;
541                        }
542                        for ( unsigned i = 0; i < areturns.size(); ++i ) {
543                                if ( ! typesPolyCompatible( areturns[i], breturns[i] ) ) return false;
544                        }
545                        return true;
546                } else if ( aid == type_index{typeid(StructInstType)} ) {
547                        StructInstType *aa = as<StructInstType>(a), *ba = as<StructInstType>(b);
548
549                        if ( aa->get_name() != ba->get_name() ) return false;
550                        return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
551                } else if ( aid == type_index{typeid(UnionInstType)} ) {
552                        UnionInstType *aa = as<UnionInstType>(a), *ba = as<UnionInstType>(b);
553
554                        if ( aa->get_name() != ba->get_name() ) return false;
555                        return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
556                } else if ( aid == type_index{typeid(EnumInstType)} ) {
557                        return as<EnumInstType>(a)->get_name() == as<EnumInstType>(b)->get_name();
558                } else if ( aid == type_index{typeid(TraitInstType)} ) {
559                        return as<TraitInstType>(a)->get_name() == as<TraitInstType>(b)->get_name();
560                } else if ( aid == type_index{typeid(TupleType)} ) {
561                        TupleType *at = as<TupleType>(a), *bt = as<TupleType>(b);
562
563                        vector<Type*> atypes, btypes;
564                        flattenList( at->get_types(), back_inserter( atypes ) );
565                        flattenList( bt->get_types(), back_inserter( btypes ) );
566                        if ( atypes.size() != btypes.size() ) return false;
567
568                        for ( unsigned i = 0; i < atypes.size(); ++i ) {
569                                if ( ! typesPolyCompatible( atypes[i], btypes[i] ) ) return false;
570                        }
571                        return true;
572                } else return true; // VoidType, VarArgsType, ZeroType & OneType just need the same type
573        }
574
575bool typesPolyCompatible( ast::Type const * lhs, ast::Type const * rhs ) {
576        type_index const lid = typeid(*lhs);
577
578        // Polymorphic types always match:
579        if ( type_index(typeid(ast::TypeInstType)) == lid ) return true;
580
581        type_index const rid = typeid(*rhs);
582        if ( type_index(typeid(ast::TypeInstType)) == rid ) return true;
583
584        // All other types only match if they are the same type:
585        if ( lid != rid ) return false;
586
587        // So remaining types can be examined case by case.
588        // Recurse through type structure (conditions borrowed from Unify.cc).
589
590        if ( type_index(typeid(ast::BasicType)) == lid ) {
591                return as<ast::BasicType>(lhs)->kind == as<ast::BasicType>(rhs)->kind;
592        } else if ( type_index(typeid(ast::PointerType)) == lid ) {
593                ast::PointerType const * l = as<ast::PointerType>(lhs);
594                ast::PointerType const * r = as<ast::PointerType>(rhs);
595
596                // void pointers should match any other pointer type.
597                return is<ast::VoidType>( l->base.get() )
598                        || is<ast::VoidType>( r->base.get() )
599                        || typesPolyCompatible( l->base.get(), r->base.get() );
600        } else if ( type_index(typeid(ast::ReferenceType)) == lid ) {
601                ast::ReferenceType const * l = as<ast::ReferenceType>(lhs);
602                ast::ReferenceType const * r = as<ast::ReferenceType>(rhs);
603
604                // void references should match any other reference type.
605                return is<ast::VoidType>( l->base.get() )
606                        || is<ast::VoidType>( r->base.get() )
607                        || typesPolyCompatible( l->base.get(), r->base.get() );
608        } else if ( type_index(typeid(ast::ArrayType)) == lid ) {
609                ast::ArrayType const * l = as<ast::ArrayType>(lhs);
610                ast::ArrayType const * r = as<ast::ArrayType>(rhs);
611
612                if ( l->isVarLen ) {
613                        if ( !r->isVarLen ) return false;
614                } else {
615                        if ( r->isVarLen ) return false;
616
617                        auto lc = l->dimension.as<ast::ConstantExpr>();
618                        auto rc = r->dimension.as<ast::ConstantExpr>();
619                        if ( lc && rc && lc->intValue() != rc->intValue() ) {
620                                return false;
621                        }
622                }
623
624                return typesPolyCompatible( l->base.get(), r->base.get() );
625        } else if ( type_index(typeid(ast::FunctionType)) == lid ) {
626                ast::FunctionType const * l = as<ast::FunctionType>(lhs);
627                ast::FunctionType const * r = as<ast::FunctionType>(rhs);
628
629                std::vector<ast::ptr<ast::Type>> lparams, rparams;
630                flattenList( l->params, lparams );
631                flattenList( r->params, rparams );
632                if ( lparams.size() != rparams.size() ) return false;
633                for ( unsigned i = 0; i < lparams.size(); ++i ) {
634                        if ( !typesPolyCompatible( lparams[i], rparams[i] ) ) return false;
635                }
636
637                std::vector<ast::ptr<ast::Type>> lrets, rrets;
638                flattenList( l->returns, lrets );
639                flattenList( r->returns, rrets );
640                if ( lrets.size() != rrets.size() ) return false;
641                for ( unsigned i = 0; i < lrets.size(); ++i ) {
642                        if ( !typesPolyCompatible( lrets[i], rrets[i] ) ) return false;
643                }
644                return true;
645        } else if ( type_index(typeid(ast::StructInstType)) == lid ) {
646                ast::StructInstType const * l = as<ast::StructInstType>(lhs);
647                ast::StructInstType const * r = as<ast::StructInstType>(rhs);
648
649                if ( l->name != r->name ) return false;
650                return paramListsPolyCompatible( l->params, r->params );
651        } else if ( type_index(typeid(ast::UnionInstType)) == lid ) {
652                ast::UnionInstType const * l = as<ast::UnionInstType>(lhs);
653                ast::UnionInstType const * r = as<ast::UnionInstType>(rhs);
654
655                if ( l->name != r->name ) return false;
656                return paramListsPolyCompatible( l->params, r->params );
657        } else if ( type_index(typeid(ast::EnumInstType)) == lid ) {
658                ast::EnumInstType const * l = as<ast::EnumInstType>(lhs);
659                ast::EnumInstType const * r = as<ast::EnumInstType>(rhs);
660
661                return l->name == r->name;
662        } else if ( type_index(typeid(ast::TraitInstType)) == lid ) {
663                ast::TraitInstType const * l = as<ast::TraitInstType>(lhs);
664                ast::TraitInstType const * r = as<ast::TraitInstType>(rhs);
665
666                return l->name == r->name;
667        } else if ( type_index(typeid(ast::TupleType)) == lid ) {
668                ast::TupleType const * l = as<ast::TupleType>(lhs);
669                ast::TupleType const * r = as<ast::TupleType>(rhs);
670
671                std::vector<ast::ptr<ast::Type>> ltypes, rtypes;
672                flattenList( l->types, ( ltypes ) );
673                flattenList( r->types, ( rtypes ) );
674                if ( ltypes.size() != rtypes.size() ) return false;
675
676                for ( unsigned i = 0 ; i < ltypes.size() ; ++i ) {
677                        if ( !typesPolyCompatible( ltypes[i], rtypes[i] ) ) return false;
678                }
679                return true;
680        // The remaining types (VoidType, VarArgsType, ZeroType & OneType)
681        // have no variation so will always be equal.
682        } else {
683                return true;
684        }
685}
686
687        namespace {
688                // temporary hack to avoid re-implementing anything related to TyVarMap
689                // does this work? these two structs have identical definitions.
690                inline TypeDecl::Data convData(const ast::TypeDecl::Data & data) {
691                        return *reinterpret_cast<const TypeDecl::Data *>(&data);
692                }
693        }
694
695        bool needsBoxing( Type * param, Type * arg, const TyVarMap &exprTyVars, const TypeSubstitution * env ) {
696                // is parameter is not polymorphic, don't need to box
697                if ( ! isPolyType( param, exprTyVars ) ) return false;
698                Type * newType = arg->clone();
699                if ( env ) env->apply( newType );
700                std::unique_ptr<Type> manager( newType );
701                // if the argument's type is polymorphic, we don't need to box again!
702                return ! isPolyType( newType );
703        }
704
705        bool needsBoxing( const ast::Type * param, const ast::Type * arg, const TyVarMap &exprTyVars, const ast::TypeSubstitution * env) {
706                // is parameter is not polymorphic, don't need to box
707                if ( ! isPolyType( param, exprTyVars ) ) return false;
708                ast::ptr<ast::Type> newType = arg;
709                if ( env ) env->apply( newType );
710                // if the argument's type is polymorphic, we don't need to box again!
711                return ! isPolyType( newType );
712        }
713
714        bool needsBoxing( Type * param, Type * arg, ApplicationExpr * appExpr, const TypeSubstitution * env ) {
715                FunctionType * function = getFunctionType( appExpr->function->result );
716                assertf( function, "ApplicationExpr has non-function type: %s", toString( appExpr->function->result ).c_str() );
717                TyVarMap exprTyVars( TypeDecl::Data{} );
718                makeTyVarMap( function, exprTyVars );
719                return needsBoxing( param, arg, exprTyVars, env );
720        }
721
722        bool needsBoxing( const ast::Type * param, const ast::Type * arg, const ast::ApplicationExpr * appExpr, const ast::TypeSubstitution * env) {
723                const ast::FunctionType * function = getFunctionType(appExpr->func->result);
724                assertf( function, "ApplicationExpr has non-function type: %s", toString( appExpr->func->result ).c_str() );
725                TyVarMap exprTyVars(TypeDecl::Data{});
726                makeTyVarMap(function, exprTyVars);
727                return needsBoxing(param, arg, exprTyVars, env);
728
729        }
730
731        void addToTyVarMap( TypeDecl * tyVar, TyVarMap &tyVarMap ) {
732                tyVarMap.insert( tyVar->name, TypeDecl::Data{ tyVar } );
733        }
734
735        void addToTyVarMap( const ast::TypeInstType * tyVar, TyVarMap & tyVarMap) {
736                tyVarMap.insert(tyVar->typeString(), convData(ast::TypeDecl::Data{tyVar->base}));
737        }
738
739        void makeTyVarMap( Type *type, TyVarMap &tyVarMap ) {
740                for ( Type::ForallList::const_iterator tyVar = type->get_forall().begin(); tyVar != type->get_forall().end(); ++tyVar ) {
741                        assert( *tyVar );
742                        addToTyVarMap( *tyVar, tyVarMap );
743                }
744                if ( PointerType *pointer = dynamic_cast< PointerType* >( type ) ) {
745                        makeTyVarMap( pointer->get_base(), tyVarMap );
746                }
747        }
748
749        void makeTyVarMap(const ast::Type * type, TyVarMap & tyVarMap) {
750                if (auto ptype = dynamic_cast<const ast::FunctionType *>(type)) {
751                        for (auto & tyVar : ptype->forall) {
752                                assert (tyVar);
753                                addToTyVarMap(tyVar, tyVarMap);
754                        }
755                }
756                if (auto pointer = dynamic_cast<const ast::PointerType *>(type)) {
757                        makeTyVarMap(pointer->base, tyVarMap);
758                }
759        }
760
761        void printTyVarMap( std::ostream &os, const TyVarMap &tyVarMap ) {
762                for ( TyVarMap::const_iterator i = tyVarMap.begin(); i != tyVarMap.end(); ++i ) {
763                        os << i->first << " (" << i->second << ") ";
764                } // for
765                os << std::endl;
766        }
767
768} // namespace GenPoly
769
770// Local Variables: //
771// tab-width: 4 //
772// mode: c++ //
773// compile-command: "make install" //
774// End: //
Note: See TracBrowser for help on using the repository browser.