source: src/GenPoly/GenPoly.cc @ 5408b59

ADTast-experimental
Last change on this file since 5408b59 was c8837e5, checked in by Andrew Beach <ajbeach@…>, 21 months ago

Rewrite in GenPoly? to avoid mixing new AST and TyVarMap? (which internally has old AST code). Some nearby functions got writen out even though they are not used, and hence not tested.

  • Property mode set to 100644
File size: 31.5 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// GenPoly.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Andrew Beach
12// Last Modified On : Fri Oct  7 15:25:00 2022
13// Update Count     : 16
14//
15
16#include "GenPoly.h"
17
18#include <cassert>                      // for assertf, assert
19#include <iostream>                     // for operator<<, ostream, basic_os...
20#include <iterator>                     // for back_insert_iterator, back_in...
21#include <list>                         // for list, _List_iterator, list<>:...
22#include <typeindex>                    // for type_index
23#include <utility>                      // for pair
24#include <vector>                       // for vector
25
26#include "AST/Type.hpp"
27#include "GenPoly/ErasableScopedMap.h"  // for ErasableScopedMap<>::const_it...
28#include "ResolvExpr/typeops.h"         // for flatten
29#include "SynTree/Constant.h"           // for Constant
30#include "SynTree/Expression.h"         // for Expression, TypeExpr, Constan...
31#include "SynTree/Type.h"               // for Type, StructInstType, UnionIn...
32#include "SynTree/TypeSubstitution.h"   // for TypeSubstitution
33
34using namespace std;
35
36namespace GenPoly {
37        namespace {
38                /// Checks a parameter list for polymorphic parameters; will substitute according to env if present
39                bool hasPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
40                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
41                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
42                                assertf(paramType, "Aggregate parameters should be type expressions");
43                                if ( isPolyType( paramType->get_type(), env ) ) return true;
44                        }
45                        return false;
46                }
47
48                bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const ast::TypeSubstitution * env) {
49                        for (auto &param : params) {
50                                auto paramType = param.strict_as<ast::TypeExpr>();
51                                if (isPolyType(paramType->type, env)) return true;
52                        }
53                        return false;
54                }
55
56                /// Checks a parameter list for polymorphic parameters from tyVars; will substitute according to env if present
57                bool hasPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
58                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
59                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
60                                assertf(paramType, "Aggregate parameters should be type expressions");
61                                if ( isPolyType( paramType->get_type(), tyVars, env ) ) return true;
62                        }
63                        return false;
64                }
65
66                /// Checks a parameter list for dynamic-layout parameters from tyVars; will substitute according to env if present
67                bool hasDynParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
68                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
69                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
70                                assertf(paramType, "Aggregate parameters should be type expressions");
71                                if ( isDynType( paramType->get_type(), tyVars, env ) ) return true;
72                        }
73                        return false;
74                }
75
76                bool hasDynParams(
77                                const std::vector<ast::ptr<ast::Expr>> & params,
78                                const TypeVarMap & typeVars,
79                                const ast::TypeSubstitution * subst ) {
80                        for ( ast::ptr<ast::Expr> const & paramExpr : params ) {
81                                auto param = paramExpr.as<ast::TypeExpr>();
82                                assertf( param, "Aggregate parameters should be type expressions." );
83                                if ( isDynType( param->type.get(), typeVars, subst ) ) {
84                                        return true;
85                                }
86                        }
87                        return false;
88                }
89
90                /// Checks a parameter list for inclusion of polymorphic parameters; will substitute according to env if present
91                bool includesPolyParams( std::list< Expression* >& params, const TypeSubstitution *env ) {
92                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
93                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
94                                assertf(paramType, "Aggregate parameters should be type expressions");
95                                if ( includesPolyType( paramType->get_type(), env ) ) return true;
96                        }
97                        return false;
98                }
99
100                /// Checks a parameter list for inclusion of polymorphic parameters from tyVars; will substitute according to env if present
101                bool includesPolyParams( std::list< Expression* >& params, const TyVarMap &tyVars, const TypeSubstitution *env ) {
102                        for ( std::list< Expression* >::iterator param = params.begin(); param != params.end(); ++param ) {
103                                TypeExpr *paramType = dynamic_cast< TypeExpr* >( *param );
104                                assertf(paramType, "Aggregate parameters should be type expressions");
105                                if ( includesPolyType( paramType->get_type(), tyVars, env ) ) return true;
106                        }
107                        return false;
108                }
109        }
110
111        Type* replaceTypeInst( Type* type, const TypeSubstitution* env ) {
112                if ( ! env ) return type;
113                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType* >( type ) ) {
114                        Type *newType = env->lookup( typeInst->get_name() );
115                        if ( newType ) return newType;
116                }
117                return type;
118        }
119
120        const ast::Type * replaceTypeInst(const ast::Type * type, const ast::TypeSubstitution * env) {
121                if (!env) return type;
122                if (auto typeInst = dynamic_cast<const ast::TypeInstType*> (type)) {
123                        auto newType = env->lookup(typeInst);
124                        if (newType) return newType;
125                }
126                return type;
127        }
128
129        Type *isPolyType( Type *type, const TypeSubstitution *env ) {
130                type = replaceTypeInst( type, env );
131
132                if ( dynamic_cast< TypeInstType * >( type ) ) {
133                        return type;
134                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
135                        return isPolyType( arrayType->base, env );
136                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
137                        if ( hasPolyParams( structType->get_parameters(), env ) ) return type;
138                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
139                        if ( hasPolyParams( unionType->get_parameters(), env ) ) return type;
140                }
141                return 0;
142        }
143
144        const ast::Type * isPolyType(const ast::Type * type, const ast::TypeSubstitution * env) {
145                type = replaceTypeInst( type, env );
146
147                if ( dynamic_cast< const ast::TypeInstType * >( type ) ) {
148                        return type;
149                } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
150                        return isPolyType( arrayType->base, env );
151                } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
152                        if ( hasPolyParams( structType->params, env ) ) return type;
153                } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
154                        if ( hasPolyParams( unionType->params, env ) ) return type;
155                }
156                return 0;
157        }
158
159        Type *isPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
160                type = replaceTypeInst( type, env );
161
162                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
163                        if ( tyVars.find( typeInst->get_name() ) != tyVars.end() ) {
164                                return type;
165                        }
166                } else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
167                        return isPolyType( arrayType->base, tyVars, env );
168                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
169                        if ( hasPolyParams( structType->get_parameters(), tyVars, env ) ) return type;
170                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
171                        if ( hasPolyParams( unionType->get_parameters(), tyVars, env ) ) return type;
172                }
173                return 0;
174        }
175
176        const ast::Type * isPolyType(const ast::Type * type, const TyVarMap & tyVars, const ast::TypeSubstitution * env) {
177                type = replaceTypeInst( type, env );
178
179                if ( auto typeInst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
180                        return tyVars.find(typeInst->typeString()) != tyVars.end() ? type : nullptr;
181                } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
182                        return isPolyType( arrayType->base, env );
183                } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
184                        if ( hasPolyParams( structType->params, env ) ) return type;
185                } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
186                        if ( hasPolyParams( unionType->params, env ) ) return type;
187                }
188                return nullptr;
189        }
190
191const ast::Type * isPolyType( const ast::Type * type,
192                const TypeVarMap & typeVars, const ast::TypeSubstitution * subst ) {
193        type = replaceTypeInst( type, subst );
194
195        if ( auto inst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
196                if ( typeVars.find( inst->typeString() ) != typeVars.end() ) return type;
197        } else if ( auto array = dynamic_cast< const ast::ArrayType * >( type ) ) {
198                return isPolyType( array->base, subst );
199        } else if ( auto sue = dynamic_cast< const ast::StructInstType * >( type ) ) {
200                if ( hasPolyParams( sue->params, subst ) ) return type;
201        } else if ( auto sue = dynamic_cast< const ast::UnionInstType * >( type ) ) {
202                if ( hasPolyParams( sue->params, subst ) ) return type;
203        }
204        return nullptr;
205}
206
207        ReferenceToType *isDynType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
208                type = replaceTypeInst( type, env );
209
210                if ( TypeInstType *typeInst = dynamic_cast< TypeInstType * >( type ) ) {
211                        auto var = tyVars.find( typeInst->get_name() );
212                        if ( var != tyVars.end() && var->second.isComplete ) {
213                                return typeInst;
214                        }
215                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
216                        if ( hasDynParams( structType->get_parameters(), tyVars, env ) ) return structType;
217                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
218                        if ( hasDynParams( unionType->get_parameters(), tyVars, env ) ) return unionType;
219                }
220                return 0;
221        }
222
223const ast::BaseInstType * isDynType(
224                const ast::Type * type, const TypeVarMap & typeVars,
225                const ast::TypeSubstitution * subst ) {
226        type = replaceTypeInst( type, subst );
227
228        if ( auto inst = dynamic_cast<ast::TypeInstType const *>( type ) ) {
229                auto var = typeVars.find( inst->name );
230                if ( var != typeVars.end() && var->second.isComplete ) {
231
232                }
233        } else if ( auto inst = dynamic_cast<ast::StructInstType const *>( type ) ) {
234                if ( hasDynParams( inst->params, typeVars, subst ) ) {
235                        return inst;
236                }
237        } else if ( auto inst = dynamic_cast<ast::UnionInstType const *>( type ) ) {
238                if ( hasDynParams( inst->params, typeVars, subst ) ) {
239                        return inst;
240                }
241        }
242        return nullptr;
243}
244
245        ReferenceToType *isDynRet( FunctionType *function, const TyVarMap &forallTypes ) {
246                if ( function->get_returnVals().empty() ) return 0;
247
248                return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
249        }
250
251const ast::BaseInstType *isDynRet(
252                const ast::FunctionType * type, const TypeVarMap & typeVars ) {
253        if ( type->returns.empty() ) return nullptr;
254
255        return isDynType( type->returns.front(), typeVars );
256}
257
258        ReferenceToType *isDynRet( FunctionType *function ) {
259                if ( function->get_returnVals().empty() ) return 0;
260
261                TyVarMap forallTypes( TypeDecl::Data{} );
262                makeTyVarMap( function, forallTypes );
263                return (ReferenceToType*)isDynType( function->get_returnVals().front()->get_type(), forallTypes );
264        }
265
266        bool needsAdapter( FunctionType *adaptee, const TyVarMap &tyVars ) {
267//              if ( ! adaptee->get_returnVals().empty() && isPolyType( adaptee->get_returnVals().front()->get_type(), tyVars ) ) {
268//                      return true;
269//              } // if
270                if ( isDynRet( adaptee, tyVars ) ) return true;
271
272                for ( std::list< DeclarationWithType* >::const_iterator innerArg = adaptee->get_parameters().begin(); innerArg != adaptee->get_parameters().end(); ++innerArg ) {
273//                      if ( isPolyType( (*innerArg)->get_type(), tyVars ) ) {
274                        if ( isDynType( (*innerArg)->get_type(), tyVars ) ) {
275                                return true;
276                        } // if
277                } // for
278                return false;
279        }
280
281bool needsAdapter(
282                ast::FunctionType const * adaptee, const TypeVarMap & typeVars ) {
283        if ( isDynRet( adaptee, typeVars ) ) return true;
284
285        for ( auto param : adaptee->params ) {
286                if ( isDynType( param, typeVars ) ) {
287                        return true;
288                }
289        }
290        return false;
291}
292
293        Type *isPolyPtr( Type *type, const TypeSubstitution *env ) {
294                type = replaceTypeInst( type, env );
295
296                if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
297                        return isPolyType( ptr->get_base(), env );
298                }
299                return 0;
300        }
301
302        Type *isPolyPtr( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
303                type = replaceTypeInst( type, env );
304
305                if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
306                        return isPolyType( ptr->get_base(), tyVars, env );
307                }
308                return 0;
309        }
310
311        Type * hasPolyBase( Type *type, int *levels, const TypeSubstitution *env ) {
312                int dummy;
313                if ( ! levels ) { levels = &dummy; }
314                *levels = 0;
315
316                while ( true ) {
317                        type = replaceTypeInst( type, env );
318
319                        if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
320                                type = ptr->get_base();
321                                ++(*levels);
322                        } else break;
323                }
324
325                return isPolyType( type, env );
326        }
327
328        Type * hasPolyBase( Type *type, const TyVarMap &tyVars, int *levels, const TypeSubstitution *env ) {
329                int dummy;
330                if ( ! levels ) { levels = &dummy; }
331                *levels = 0;
332
333                while ( true ) {
334                        type = replaceTypeInst( type, env );
335
336                        if ( PointerType *ptr = dynamic_cast< PointerType *>( type ) ) {
337                                type = ptr->get_base();
338                                ++(*levels);
339                        } else break;
340                }
341
342                return isPolyType( type, tyVars, env );
343        }
344
345ast::Type const * hasPolyBase(
346                ast::Type const * type, const TypeVarMap & typeVars,
347                int * levels, const ast::TypeSubstitution * subst ) {
348        int level_count = 0;
349
350        while ( true ) {
351                type = replaceTypeInst( type, subst );
352
353                if ( auto ptr = dynamic_cast<ast::PointerType const *>( type ) ) {
354                        type = ptr->base;
355                        ++level_count;
356                } else {
357                        break;
358                }
359        }
360
361        if ( nullptr != levels ) { *levels = level_count; }
362        return isPolyType( type, typeVars, subst );
363}
364
365        bool includesPolyType( Type *type, const TypeSubstitution *env ) {
366                type = replaceTypeInst( type, env );
367
368                if ( dynamic_cast< TypeInstType * >( type ) ) {
369                        return true;
370                } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
371                        if ( includesPolyType( pointerType->get_base(), env ) ) return true;
372                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
373                        if ( includesPolyParams( structType->get_parameters(), env ) ) return true;
374                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
375                        if ( includesPolyParams( unionType->get_parameters(), env ) ) return true;
376                }
377                return false;
378        }
379
380        bool includesPolyType( Type *type, const TyVarMap &tyVars, const TypeSubstitution *env ) {
381                type = replaceTypeInst( type, env );
382
383                if ( TypeInstType *typeInstType = dynamic_cast< TypeInstType * >( type ) ) {
384                        if ( tyVars.find( typeInstType->get_name() ) != tyVars.end() ) {
385                                return true;
386                        }
387                } else if ( PointerType *pointerType = dynamic_cast< PointerType* >( type ) ) {
388                        if ( includesPolyType( pointerType->get_base(), tyVars, env ) ) return true;
389                } else if ( StructInstType *structType = dynamic_cast< StructInstType* >( type ) ) {
390                        if ( includesPolyParams( structType->get_parameters(), tyVars, env ) ) return true;
391                } else if ( UnionInstType *unionType = dynamic_cast< UnionInstType* >( type ) ) {
392                        if ( includesPolyParams( unionType->get_parameters(), tyVars, env ) ) return true;
393                }
394                return false;
395        }
396
397        FunctionType * getFunctionType( Type *ty ) {
398                PointerType *ptrType;
399                if ( ( ptrType = dynamic_cast< PointerType* >( ty ) ) ) {
400                        return dynamic_cast< FunctionType* >( ptrType->get_base() ); // pointer if FunctionType, NULL otherwise
401                } else {
402                        return dynamic_cast< FunctionType* >( ty ); // pointer if FunctionType, NULL otherwise
403                }
404        }
405
406        const ast::FunctionType * getFunctionType( const ast::Type * ty ) {
407                if ( auto pty = dynamic_cast< const ast::PointerType * >( ty ) ) {
408                        return pty->base.as< ast::FunctionType >();
409                } else {
410                        return dynamic_cast< const ast::FunctionType * >( ty );
411                }
412        }
413
414        VariableExpr * getBaseVar( Expression *expr, int *levels ) {
415                int dummy;
416                if ( ! levels ) { levels = &dummy; }
417                *levels = 0;
418
419                while ( true ) {
420                        if ( VariableExpr *varExpr = dynamic_cast< VariableExpr* >( expr ) ) {
421                                return varExpr;
422                        } else if ( MemberExpr *memberExpr = dynamic_cast< MemberExpr* >( expr ) ) {
423                                expr = memberExpr->get_aggregate();
424                        } else if ( AddressExpr *addressExpr = dynamic_cast< AddressExpr* >( expr ) ) {
425                                expr = addressExpr->get_arg();
426                        } else if ( UntypedExpr *untypedExpr = dynamic_cast< UntypedExpr* >( expr ) ) {
427                                // look for compiler-inserted dereference operator
428                                NameExpr *fn = dynamic_cast< NameExpr* >( untypedExpr->get_function() );
429                                if ( ! fn || fn->get_name() != std::string("*?") ) return 0;
430                                expr = *untypedExpr->begin_args();
431                        } else if ( CommaExpr *commaExpr = dynamic_cast< CommaExpr* >( expr ) ) {
432                                // copy constructors insert comma exprs, look at second argument which contains the variable
433                                expr = commaExpr->get_arg2();
434                                continue;
435                        } else if ( ConditionalExpr * condExpr = dynamic_cast< ConditionalExpr * >( expr ) ) {
436                                int lvl1;
437                                int lvl2;
438                                VariableExpr * var1 = getBaseVar( condExpr->get_arg2(), &lvl1 );
439                                VariableExpr * var2 = getBaseVar( condExpr->get_arg3(), &lvl2 );
440                                if ( lvl1 == lvl2 && var1 && var2 && var1->get_var() == var2->get_var() ) {
441                                        *levels = lvl1;
442                                        return var1;
443                                }
444                                break;
445                        } else break;
446
447                        ++(*levels);
448                }
449
450                return 0;
451        }
452
453        namespace {
454                /// Checks if is a pointer to D
455                template<typename D, typename B>
456                bool is( const B* p ) { return type_index{typeid(D)} == type_index{typeid(*p)}; }
457
458                /// Converts to a pointer to D without checking for safety
459                template<typename D, typename B>
460                inline D* as( B* p ) { return reinterpret_cast<D*>(p); }
461
462                template<typename D, typename B>
463                inline D const * as( B const * p ) {
464                        return reinterpret_cast<D const *>( p );
465                }
466
467                /// Flattens a declaration list
468                template<typename Output>
469                void flattenList( list< DeclarationWithType* > src, Output out ) {
470                        for ( DeclarationWithType* decl : src ) {
471                                ResolvExpr::flatten( decl->get_type(), out );
472                        }
473                }
474
475                /// Flattens a list of types
476                template<typename Output>
477                void flattenList( list< Type* > src, Output out ) {
478                        for ( Type* ty : src ) {
479                                ResolvExpr::flatten( ty, out );
480                        }
481                }
482
483                void flattenList( vector<ast::ptr<ast::Type>> const & src,
484                                vector<ast::ptr<ast::Type>> & out ) {
485                        for ( auto const & type : src ) {
486                                ResolvExpr::flatten( type, out );
487                        }
488                }
489
490                /// Checks if two lists of parameters are equal up to polymorphic substitution.
491                bool paramListsPolyCompatible( const list< Expression* >& aparams, const list< Expression* >& bparams ) {
492                        if ( aparams.size() != bparams.size() ) return false;
493
494                        for ( list< Expression* >::const_iterator at = aparams.begin(), bt = bparams.begin();
495                                        at != aparams.end(); ++at, ++bt ) {
496                                TypeExpr *aparam = dynamic_cast< TypeExpr* >(*at);
497                                assertf(aparam, "Aggregate parameters should be type expressions");
498                                TypeExpr *bparam = dynamic_cast< TypeExpr* >(*bt);
499                                assertf(bparam, "Aggregate parameters should be type expressions");
500
501                                // xxx - might need to let VoidType be a wildcard here too; could have some voids
502                                // stuffed in for dtype-statics.
503                                // if ( is<VoidType>( aparam->get_type() ) || is<VoidType>( bparam->get_type() ) ) continue;
504                                if ( ! typesPolyCompatible( aparam->get_type(), bparam->get_type() ) ) return false;
505                        }
506
507                        return true;
508                }
509
510                bool paramListsPolyCompatible(
511                                std::vector<ast::ptr<ast::Expr>> const & lparams,
512                                std::vector<ast::ptr<ast::Expr>> const & rparams ) {
513                        if ( lparams.size() != rparams.size() ) {
514                                return false;
515                        }
516
517                        for ( auto lparam = lparams.begin(), rparam = rparams.begin() ;
518                                        lparam != lparams.end() ; ++lparam, ++rparam ) {
519                                ast::TypeExpr const * lexpr = lparam->as<ast::TypeExpr>();
520                                assertf( lexpr, "Aggregate parameters should be type expressions" );
521                                ast::TypeExpr const * rexpr = rparam->as<ast::TypeExpr>();
522                                assertf( rexpr, "Aggregate parameters should be type expressions" );
523
524                                // xxx - might need to let VoidType be a wildcard here too; could have some voids
525                                // stuffed in for dtype-statics.
526                                // if ( is<VoidType>( lexpr->type() ) || is<VoidType>( bparam->get_type() ) ) continue;
527                                if ( !typesPolyCompatible( lexpr->type, rexpr->type ) ) {
528                                        return false;
529                                }
530                        }
531
532                        return true;
533                }
534        }
535
536        bool typesPolyCompatible( Type *a, Type *b ) {
537                type_index aid{ typeid(*a) };
538                // polymorphic types always match
539                if ( aid == type_index{typeid(TypeInstType)} ) return true;
540
541                type_index bid{ typeid(*b) };
542                // polymorphic types always match
543                if ( bid == type_index{typeid(TypeInstType)} ) return true;
544
545                // can't match otherwise if different types
546                if ( aid != bid ) return false;
547
548                // recurse through type structure (conditions borrowed from Unify.cc)
549                if ( aid == type_index{typeid(BasicType)} ) {
550                        return as<BasicType>(a)->get_kind() == as<BasicType>(b)->get_kind();
551                } else if ( aid == type_index{typeid(PointerType)} ) {
552                        PointerType *ap = as<PointerType>(a), *bp = as<PointerType>(b);
553
554                        // void pointers should match any other pointer type
555                        return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
556                                || typesPolyCompatible( ap->get_base(), bp->get_base() );
557                } else if ( aid == type_index{typeid(ReferenceType)} ) {
558                        ReferenceType *ap = as<ReferenceType>(a), *bp = as<ReferenceType>(b);
559                        return is<VoidType>( ap->get_base() ) || is<VoidType>( bp->get_base() )
560                                || typesPolyCompatible( ap->get_base(), bp->get_base() );
561                } else if ( aid == type_index{typeid(ArrayType)} ) {
562                        ArrayType *aa = as<ArrayType>(a), *ba = as<ArrayType>(b);
563
564                        if ( aa->get_isVarLen() ) {
565                                if ( ! ba->get_isVarLen() ) return false;
566                        } else {
567                                if ( ba->get_isVarLen() ) return false;
568
569                                ConstantExpr *ad = dynamic_cast<ConstantExpr*>( aa->get_dimension() );
570                                ConstantExpr *bd = dynamic_cast<ConstantExpr*>( ba->get_dimension() );
571                                if ( ad && bd
572                                                && ad->get_constant()->get_value() != bd->get_constant()->get_value() )
573                                        return false;
574                        }
575
576                        return typesPolyCompatible( aa->get_base(), ba->get_base() );
577                } else if ( aid == type_index{typeid(FunctionType)} ) {
578                        FunctionType *af = as<FunctionType>(a), *bf = as<FunctionType>(b);
579
580                        vector<Type*> aparams, bparams;
581                        flattenList( af->get_parameters(), back_inserter( aparams ) );
582                        flattenList( bf->get_parameters(), back_inserter( bparams ) );
583                        if ( aparams.size() != bparams.size() ) return false;
584
585                        vector<Type*> areturns, breturns;
586                        flattenList( af->get_returnVals(), back_inserter( areturns ) );
587                        flattenList( bf->get_returnVals(), back_inserter( breturns ) );
588                        if ( areturns.size() != breturns.size() ) return false;
589
590                        for ( unsigned i = 0; i < aparams.size(); ++i ) {
591                                if ( ! typesPolyCompatible( aparams[i], bparams[i] ) ) return false;
592                        }
593                        for ( unsigned i = 0; i < areturns.size(); ++i ) {
594                                if ( ! typesPolyCompatible( areturns[i], breturns[i] ) ) return false;
595                        }
596                        return true;
597                } else if ( aid == type_index{typeid(StructInstType)} ) {
598                        StructInstType *aa = as<StructInstType>(a), *ba = as<StructInstType>(b);
599
600                        if ( aa->get_name() != ba->get_name() ) return false;
601                        return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
602                } else if ( aid == type_index{typeid(UnionInstType)} ) {
603                        UnionInstType *aa = as<UnionInstType>(a), *ba = as<UnionInstType>(b);
604
605                        if ( aa->get_name() != ba->get_name() ) return false;
606                        return paramListsPolyCompatible( aa->get_parameters(), ba->get_parameters() );
607                } else if ( aid == type_index{typeid(EnumInstType)} ) {
608                        return as<EnumInstType>(a)->get_name() == as<EnumInstType>(b)->get_name();
609                } else if ( aid == type_index{typeid(TraitInstType)} ) {
610                        return as<TraitInstType>(a)->get_name() == as<TraitInstType>(b)->get_name();
611                } else if ( aid == type_index{typeid(TupleType)} ) {
612                        TupleType *at = as<TupleType>(a), *bt = as<TupleType>(b);
613
614                        vector<Type*> atypes, btypes;
615                        flattenList( at->get_types(), back_inserter( atypes ) );
616                        flattenList( bt->get_types(), back_inserter( btypes ) );
617                        if ( atypes.size() != btypes.size() ) return false;
618
619                        for ( unsigned i = 0; i < atypes.size(); ++i ) {
620                                if ( ! typesPolyCompatible( atypes[i], btypes[i] ) ) return false;
621                        }
622                        return true;
623                } else return true; // VoidType, VarArgsType, ZeroType & OneType just need the same type
624        }
625
626bool typesPolyCompatible( ast::Type const * lhs, ast::Type const * rhs ) {
627        type_index const lid = typeid(*lhs);
628
629        // Polymorphic types always match:
630        if ( type_index(typeid(ast::TypeInstType)) == lid ) return true;
631
632        type_index const rid = typeid(*rhs);
633        if ( type_index(typeid(ast::TypeInstType)) == rid ) return true;
634
635        // All other types only match if they are the same type:
636        if ( lid != rid ) return false;
637
638        // So remaining types can be examined case by case.
639        // Recurse through type structure (conditions borrowed from Unify.cc).
640
641        if ( type_index(typeid(ast::BasicType)) == lid ) {
642                return as<ast::BasicType>(lhs)->kind == as<ast::BasicType>(rhs)->kind;
643        } else if ( type_index(typeid(ast::PointerType)) == lid ) {
644                ast::PointerType const * l = as<ast::PointerType>(lhs);
645                ast::PointerType const * r = as<ast::PointerType>(rhs);
646
647                // void pointers should match any other pointer type.
648                return is<ast::VoidType>( l->base.get() )
649                        || is<ast::VoidType>( r->base.get() )
650                        || typesPolyCompatible( l->base.get(), r->base.get() );
651        } else if ( type_index(typeid(ast::ReferenceType)) == lid ) {
652                ast::ReferenceType const * l = as<ast::ReferenceType>(lhs);
653                ast::ReferenceType const * r = as<ast::ReferenceType>(rhs);
654
655                // void references should match any other reference type.
656                return is<ast::VoidType>( l->base.get() )
657                        || is<ast::VoidType>( r->base.get() )
658                        || typesPolyCompatible( l->base.get(), r->base.get() );
659        } else if ( type_index(typeid(ast::ArrayType)) == lid ) {
660                ast::ArrayType const * l = as<ast::ArrayType>(lhs);
661                ast::ArrayType const * r = as<ast::ArrayType>(rhs);
662
663                if ( l->isVarLen ) {
664                        if ( !r->isVarLen ) return false;
665                } else {
666                        if ( r->isVarLen ) return false;
667
668                        auto lc = l->dimension.as<ast::ConstantExpr>();
669                        auto rc = r->dimension.as<ast::ConstantExpr>();
670                        if ( lc && rc && lc->intValue() != rc->intValue() ) {
671                                return false;
672                        }
673                }
674
675                return typesPolyCompatible( l->base.get(), r->base.get() );
676        } else if ( type_index(typeid(ast::FunctionType)) == lid ) {
677                ast::FunctionType const * l = as<ast::FunctionType>(lhs);
678                ast::FunctionType const * r = as<ast::FunctionType>(rhs);
679
680                std::vector<ast::ptr<ast::Type>> lparams, rparams;
681                flattenList( l->params, lparams );
682                flattenList( r->params, rparams );
683                if ( lparams.size() != rparams.size() ) return false;
684                for ( unsigned i = 0; i < lparams.size(); ++i ) {
685                        if ( !typesPolyCompatible( lparams[i], rparams[i] ) ) return false;
686                }
687
688                std::vector<ast::ptr<ast::Type>> lrets, rrets;
689                flattenList( l->returns, lrets );
690                flattenList( r->returns, rrets );
691                if ( lrets.size() != rrets.size() ) return false;
692                for ( unsigned i = 0; i < lrets.size(); ++i ) {
693                        if ( !typesPolyCompatible( lrets[i], rrets[i] ) ) return false;
694                }
695                return true;
696        } else if ( type_index(typeid(ast::StructInstType)) == lid ) {
697                ast::StructInstType const * l = as<ast::StructInstType>(lhs);
698                ast::StructInstType const * r = as<ast::StructInstType>(rhs);
699
700                if ( l->name != r->name ) return false;
701                return paramListsPolyCompatible( l->params, r->params );
702        } else if ( type_index(typeid(ast::UnionInstType)) == lid ) {
703                ast::UnionInstType const * l = as<ast::UnionInstType>(lhs);
704                ast::UnionInstType const * r = as<ast::UnionInstType>(rhs);
705
706                if ( l->name != r->name ) return false;
707                return paramListsPolyCompatible( l->params, r->params );
708        } else if ( type_index(typeid(ast::EnumInstType)) == lid ) {
709                ast::EnumInstType const * l = as<ast::EnumInstType>(lhs);
710                ast::EnumInstType const * r = as<ast::EnumInstType>(rhs);
711
712                return l->name == r->name;
713        } else if ( type_index(typeid(ast::TraitInstType)) == lid ) {
714                ast::TraitInstType const * l = as<ast::TraitInstType>(lhs);
715                ast::TraitInstType const * r = as<ast::TraitInstType>(rhs);
716
717                return l->name == r->name;
718        } else if ( type_index(typeid(ast::TupleType)) == lid ) {
719                ast::TupleType const * l = as<ast::TupleType>(lhs);
720                ast::TupleType const * r = as<ast::TupleType>(rhs);
721
722                std::vector<ast::ptr<ast::Type>> ltypes, rtypes;
723                flattenList( l->types, ( ltypes ) );
724                flattenList( r->types, ( rtypes ) );
725                if ( ltypes.size() != rtypes.size() ) return false;
726
727                for ( unsigned i = 0 ; i < ltypes.size() ; ++i ) {
728                        if ( !typesPolyCompatible( ltypes[i], rtypes[i] ) ) return false;
729                }
730                return true;
731        // The remaining types (VoidType, VarArgsType, ZeroType & OneType)
732        // have no variation so will always be equal.
733        } else {
734                return true;
735        }
736}
737
738        bool needsBoxing( Type * param, Type * arg, const TyVarMap &exprTyVars, const TypeSubstitution * env ) {
739                // is parameter is not polymorphic, don't need to box
740                if ( ! isPolyType( param, exprTyVars ) ) return false;
741                Type * newType = arg->clone();
742                if ( env ) env->apply( newType );
743                std::unique_ptr<Type> manager( newType );
744                // if the argument's type is polymorphic, we don't need to box again!
745                return ! isPolyType( newType );
746        }
747
748bool needsBoxing( const ast::Type * param, const ast::Type * arg,
749                const TypeVarMap & typeVars, const ast::TypeSubstitution * subst ) {
750        // Don't need to box if the parameter is not polymorphic.
751        if ( !isPolyType( param, typeVars ) ) return false;
752
753        ast::ptr<ast::Type> newType = arg;
754        if ( subst ) {
755                int count = subst->apply( newType );
756                (void)count;
757        }
758        // Only need to box if the argument is not also polymorphic.
759        return !isPolyType( newType );
760}
761
762        bool needsBoxing( Type * param, Type * arg, ApplicationExpr * appExpr, const TypeSubstitution * env ) {
763                FunctionType * function = getFunctionType( appExpr->function->result );
764                assertf( function, "ApplicationExpr has non-function type: %s", toString( appExpr->function->result ).c_str() );
765                TyVarMap exprTyVars( TypeDecl::Data{} );
766                makeTyVarMap( function, exprTyVars );
767                return needsBoxing( param, arg, exprTyVars, env );
768        }
769
770bool needsBoxing(
771                const ast::Type * param, const ast::Type * arg,
772                const ast::ApplicationExpr * expr,
773                const ast::TypeSubstitution * subst ) {
774        const ast::FunctionType * function = getFunctionType( expr->func->result );
775        assertf( function, "ApplicationExpr has non-function type: %s", toString( expr->func->result ).c_str() );
776        TypeVarMap exprTyVars = { ast::TypeDecl::Data() };
777        makeTypeVarMap( function, exprTyVars );
778        return needsBoxing( param, arg, exprTyVars, subst );
779}
780
781        void addToTyVarMap( TypeDecl * tyVar, TyVarMap &tyVarMap ) {
782                tyVarMap.insert( tyVar->name, TypeDecl::Data{ tyVar } );
783        }
784
785void addToTypeVarMap( const ast::TypeInstType * type, TypeVarMap & typeVars ) {
786        typeVars.insert( type->typeString(), ast::TypeDecl::Data( type->base ) );
787}
788
789        void makeTyVarMap( Type *type, TyVarMap &tyVarMap ) {
790                for ( Type::ForallList::const_iterator tyVar = type->get_forall().begin(); tyVar != type->get_forall().end(); ++tyVar ) {
791                        assert( *tyVar );
792                        addToTyVarMap( *tyVar, tyVarMap );
793                }
794                if ( PointerType *pointer = dynamic_cast< PointerType* >( type ) ) {
795                        makeTyVarMap( pointer->get_base(), tyVarMap );
796                }
797        }
798
799void makeTypeVarMap( const ast::Type * type, TypeVarMap & typeVars ) {
800        if ( auto func = dynamic_cast<ast::FunctionType const *>( type ) ) {
801                for ( auto & typeVar : func->forall ) {
802                        assert( typeVar );
803                        addToTypeVarMap( typeVar, typeVars );
804                }
805        }
806        if ( auto pointer = dynamic_cast<ast::PointerType const *>( type ) ) {
807                makeTypeVarMap( pointer->base, typeVars );
808        }
809}
810
811        void printTyVarMap( std::ostream &os, const TyVarMap &tyVarMap ) {
812                for ( TyVarMap::const_iterator i = tyVarMap.begin(); i != tyVarMap.end(); ++i ) {
813                        os << i->first << " (" << i->second << ") ";
814                } // for
815                os << std::endl;
816        }
817
818void printTypeVarMap( std::ostream &os, const TypeVarMap & typeVars ) {
819        for ( auto const & pair : typeVars ) {
820                os << pair.first << " (" << pair.second << ") ";
821        } // for
822        os << std::endl;
823}
824
825} // namespace GenPoly
826
827// Local Variables: //
828// tab-width: 4 //
829// mode: c++ //
830// compile-command: "make install" //
831// End: //
Note: See TracBrowser for help on using the repository browser.