source: src/GenPoly/GenPoly.cc @ c6b4432

Last change on this file since c6b4432 was c6b4432, checked in by Andrew Beach <ajbeach@…>, 6 months ago

Remove BaseSyntaxNode? and clean-up.

  • Property mode set to 100644
File size: 14.5 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// GenPoly.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Andrew Beach
12// Last Modified On : Mon Oct 24 15:19:00 2022
13// Update Count     : 17
14//
15
16#include "GenPoly.h"
17
18#include <cassert>                      // for assertf, assert
19#include <iostream>                     // for operator<<, ostream, basic_os...
20#include <iterator>                     // for back_insert_iterator, back_in...
21#include <list>                         // for list, _List_iterator, list<>:...
22#include <typeindex>                    // for type_index
23#include <utility>                      // for pair
24#include <vector>                       // for vector
25
26#include "AST/Expr.hpp"
27#include "AST/Type.hpp"
28#include "AST/TypeSubstitution.hpp"
29#include "GenPoly/ErasableScopedMap.h"  // for ErasableScopedMap<>::const_it...
30#include "ResolvExpr/typeops.h"         // for flatten
31
32using namespace std;
33
34namespace GenPoly {
35        namespace {
36                /// Checks a parameter list for polymorphic parameters; will substitute according to env if present
37                bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const ast::TypeSubstitution * env ) {
38                        for ( auto &param : params ) {
39                                auto paramType = param.as<ast::TypeExpr>();
40                                assertf( paramType, "Aggregate parameters should be type expressions" );
41                                if ( isPolyType( paramType->type, env ) ) return true;
42                        }
43                        return false;
44                }
45
46                /// Checks a parameter list for polymorphic parameters from tyVars; will substitute according to env if present
47                bool hasPolyParams( const std::vector<ast::ptr<ast::Expr>> & params, const TypeVarMap & typeVars, const ast::TypeSubstitution * env ) {
48                        for ( auto & param : params ) {
49                                auto paramType = param.as<ast::TypeExpr>();
50                                assertf( paramType, "Aggregate parameters should be type expressions" );
51                                if ( isPolyType( paramType->type, typeVars, env ) ) return true;
52                        }
53                        return false;
54                }
55
56                /// Checks a parameter list for dynamic-layout parameters from tyVars; will substitute according to env if present
57                bool hasDynParams(
58                                const std::vector<ast::ptr<ast::Expr>> & params,
59                                const TypeVarMap & typeVars,
60                                const ast::TypeSubstitution * subst ) {
61                        for ( ast::ptr<ast::Expr> const & paramExpr : params ) {
62                                auto param = paramExpr.as<ast::TypeExpr>();
63                                assertf( param, "Aggregate parameters should be type expressions." );
64                                if ( isDynType( param->type.get(), typeVars, subst ) ) {
65                                        return true;
66                                }
67                        }
68                        return false;
69                }
70        }
71
72        const ast::Type * replaceTypeInst(const ast::Type * type, const ast::TypeSubstitution * env) {
73                if (!env) return type;
74                if ( auto typeInst = dynamic_cast<const ast::TypeInstType*>(type) ) {
75                        auto newType = env->lookup(typeInst);
76                        if (newType) return newType;
77                }
78                return type;
79        }
80
81        const ast::Type * isPolyType(const ast::Type * type, const ast::TypeSubstitution * env) {
82                type = replaceTypeInst( type, env );
83
84                if ( dynamic_cast< const ast::TypeInstType * >( type ) ) {
85                        return type;
86                } else if ( auto arrayType = dynamic_cast< const ast::ArrayType * >( type ) ) {
87                        return isPolyType( arrayType->base, env );
88                } else if ( auto structType = dynamic_cast< const ast::StructInstType* >( type ) ) {
89                        if ( hasPolyParams( structType->params, env ) ) return type;
90                } else if ( auto unionType = dynamic_cast< const ast::UnionInstType* >( type ) ) {
91                        if ( hasPolyParams( unionType->params, env ) ) return type;
92                }
93                return 0;
94        }
95
96const ast::Type * isPolyType( const ast::Type * type,
97                const TypeVarMap & typeVars, const ast::TypeSubstitution * subst ) {
98        type = replaceTypeInst( type, subst );
99
100        if ( auto inst = dynamic_cast< const ast::TypeInstType * >( type ) ) {
101                if ( typeVars.contains( *inst ) ) return type;
102        } else if ( auto array = dynamic_cast< const ast::ArrayType * >( type ) ) {
103                return isPolyType( array->base, typeVars, subst );
104        } else if ( auto sue = dynamic_cast< const ast::StructInstType * >( type ) ) {
105                if ( hasPolyParams( sue->params, typeVars, subst ) ) return type;
106        } else if ( auto sue = dynamic_cast< const ast::UnionInstType * >( type ) ) {
107                if ( hasPolyParams( sue->params, typeVars, subst ) ) return type;
108        }
109        return nullptr;
110}
111
112const ast::BaseInstType * isDynType(
113                const ast::Type * type, const TypeVarMap & typeVars,
114                const ast::TypeSubstitution * subst ) {
115        type = replaceTypeInst( type, subst );
116
117        if ( auto inst = dynamic_cast<ast::TypeInstType const *>( type ) ) {
118                auto var = typeVars.find( *inst );
119                if ( var != typeVars.end() && var->second.isComplete ) {
120                        return inst;
121                }
122        } else if ( auto inst = dynamic_cast<ast::StructInstType const *>( type ) ) {
123                if ( hasDynParams( inst->params, typeVars, subst ) ) {
124                        return inst;
125                }
126        } else if ( auto inst = dynamic_cast<ast::UnionInstType const *>( type ) ) {
127                if ( hasDynParams( inst->params, typeVars, subst ) ) {
128                        return inst;
129                }
130        }
131        return nullptr;
132}
133
134const ast::BaseInstType *isDynRet(
135                const ast::FunctionType * type, const TypeVarMap & typeVars ) {
136        if ( type->returns.empty() ) return nullptr;
137
138        return isDynType( type->returns.front(), typeVars );
139}
140
141const ast::BaseInstType *isDynRet( const ast::FunctionType * func ) {
142        if ( func->returns.empty() ) return nullptr;
143
144        TypeVarMap forallTypes;
145        makeTypeVarMap( func, forallTypes );
146        return isDynType( func->returns.front(), forallTypes );
147}
148
149bool needsAdapter(
150                ast::FunctionType const * adaptee, const TypeVarMap & typeVars ) {
151        if ( isDynRet( adaptee, typeVars ) ) return true;
152
153        for ( auto param : adaptee->params ) {
154                if ( isDynType( param, typeVars ) ) {
155                        return true;
156                }
157        }
158        return false;
159}
160
161const ast::Type * isPolyPtr(
162                const ast::Type * type, const TypeVarMap & typeVars,
163                const ast::TypeSubstitution * typeSubs ) {
164        type = replaceTypeInst( type, typeSubs );
165
166        if ( auto * ptr = dynamic_cast<ast::PointerType const *>( type ) ) {
167                return isPolyType( ptr->base, typeVars, typeSubs );
168        }
169        return nullptr;
170}
171
172ast::Type const * hasPolyBase(
173                ast::Type const * type, const TypeVarMap & typeVars,
174                int * levels, const ast::TypeSubstitution * subst ) {
175        int level_count = 0;
176
177        while ( true ) {
178                type = replaceTypeInst( type, subst );
179
180                if ( auto ptr = dynamic_cast<ast::PointerType const *>( type ) ) {
181                        type = ptr->base;
182                        ++level_count;
183                } else {
184                        break;
185                }
186        }
187
188        if ( nullptr != levels ) { *levels = level_count; }
189        return isPolyType( type, typeVars, subst );
190}
191
192        const ast::FunctionType * getFunctionType( const ast::Type * ty ) {
193                if ( auto pty = dynamic_cast< const ast::PointerType * >( ty ) ) {
194                        return pty->base.as< ast::FunctionType >();
195                } else {
196                        return dynamic_cast< const ast::FunctionType * >( ty );
197                }
198        }
199
200        namespace {
201                /// Checks if is a pointer to D
202                template<typename D, typename B>
203                bool is( const B* p ) { return type_index{typeid(D)} == type_index{typeid(*p)}; }
204
205                /// Converts to a pointer to D without checking for safety
206                template<typename D, typename B>
207                inline D* as( B* p ) { return reinterpret_cast<D*>(p); }
208
209                template<typename D, typename B>
210                inline D const * as( B const * p ) {
211                        return reinterpret_cast<D const *>( p );
212                }
213
214                /// Flattens a list of types.
215                // There is another flattenList in Unify.
216                void flattenList( vector<ast::ptr<ast::Type>> const & src,
217                                vector<ast::ptr<ast::Type>> & out ) {
218                        for ( auto const & type : src ) {
219                                ResolvExpr::flatten( type, out );
220                        }
221                }
222
223                bool paramListsPolyCompatible(
224                                std::vector<ast::ptr<ast::Expr>> const & lparams,
225                                std::vector<ast::ptr<ast::Expr>> const & rparams ) {
226                        if ( lparams.size() != rparams.size() ) {
227                                return false;
228                        }
229
230                        for ( auto lparam = lparams.begin(), rparam = rparams.begin() ;
231                                        lparam != lparams.end() ; ++lparam, ++rparam ) {
232                                ast::TypeExpr const * lexpr = lparam->as<ast::TypeExpr>();
233                                assertf( lexpr, "Aggregate parameters should be type expressions" );
234                                ast::TypeExpr const * rexpr = rparam->as<ast::TypeExpr>();
235                                assertf( rexpr, "Aggregate parameters should be type expressions" );
236
237                                // xxx - might need to let VoidType be a wildcard here too; could have some voids
238                                // stuffed in for dtype-statics.
239                                // if ( is<VoidType>( lexpr->type() ) || is<VoidType>( bparam->get_type() ) ) continue;
240                                if ( !typesPolyCompatible( lexpr->type, rexpr->type ) ) {
241                                        return false;
242                                }
243                        }
244
245                        return true;
246                }
247        }
248
249bool typesPolyCompatible( ast::Type const * lhs, ast::Type const * rhs ) {
250        type_index const lid = typeid(*lhs);
251
252        // Polymorphic types always match:
253        if ( type_index(typeid(ast::TypeInstType)) == lid ) return true;
254
255        type_index const rid = typeid(*rhs);
256        if ( type_index(typeid(ast::TypeInstType)) == rid ) return true;
257
258        // All other types only match if they are the same type:
259        if ( lid != rid ) return false;
260
261        // So remaining types can be examined case by case.
262        // Recurse through type structure (conditions borrowed from Unify.cc).
263
264        if ( type_index(typeid(ast::BasicType)) == lid ) {
265                return as<ast::BasicType>(lhs)->kind == as<ast::BasicType>(rhs)->kind;
266        } else if ( type_index(typeid(ast::PointerType)) == lid ) {
267                ast::PointerType const * l = as<ast::PointerType>(lhs);
268                ast::PointerType const * r = as<ast::PointerType>(rhs);
269
270                // void pointers should match any other pointer type.
271                return is<ast::VoidType>( l->base.get() )
272                        || is<ast::VoidType>( r->base.get() )
273                        || typesPolyCompatible( l->base.get(), r->base.get() );
274        } else if ( type_index(typeid(ast::ReferenceType)) == lid ) {
275                ast::ReferenceType const * l = as<ast::ReferenceType>(lhs);
276                ast::ReferenceType const * r = as<ast::ReferenceType>(rhs);
277
278                // void references should match any other reference type.
279                return is<ast::VoidType>( l->base.get() )
280                        || is<ast::VoidType>( r->base.get() )
281                        || typesPolyCompatible( l->base.get(), r->base.get() );
282        } else if ( type_index(typeid(ast::ArrayType)) == lid ) {
283                ast::ArrayType const * l = as<ast::ArrayType>(lhs);
284                ast::ArrayType const * r = as<ast::ArrayType>(rhs);
285
286                if ( l->isVarLen ) {
287                        if ( !r->isVarLen ) return false;
288                } else {
289                        if ( r->isVarLen ) return false;
290
291                        auto lc = l->dimension.as<ast::ConstantExpr>();
292                        auto rc = r->dimension.as<ast::ConstantExpr>();
293                        if ( lc && rc && lc->intValue() != rc->intValue() ) {
294                                return false;
295                        }
296                }
297
298                return typesPolyCompatible( l->base.get(), r->base.get() );
299        } else if ( type_index(typeid(ast::FunctionType)) == lid ) {
300                ast::FunctionType const * l = as<ast::FunctionType>(lhs);
301                ast::FunctionType const * r = as<ast::FunctionType>(rhs);
302
303                std::vector<ast::ptr<ast::Type>> lparams, rparams;
304                flattenList( l->params, lparams );
305                flattenList( r->params, rparams );
306                if ( lparams.size() != rparams.size() ) return false;
307                for ( unsigned i = 0; i < lparams.size(); ++i ) {
308                        if ( !typesPolyCompatible( lparams[i], rparams[i] ) ) return false;
309                }
310
311                std::vector<ast::ptr<ast::Type>> lrets, rrets;
312                flattenList( l->returns, lrets );
313                flattenList( r->returns, rrets );
314                if ( lrets.size() != rrets.size() ) return false;
315                for ( unsigned i = 0; i < lrets.size(); ++i ) {
316                        if ( !typesPolyCompatible( lrets[i], rrets[i] ) ) return false;
317                }
318                return true;
319        } else if ( type_index(typeid(ast::StructInstType)) == lid ) {
320                ast::StructInstType const * l = as<ast::StructInstType>(lhs);
321                ast::StructInstType const * r = as<ast::StructInstType>(rhs);
322
323                if ( l->name != r->name ) return false;
324                return paramListsPolyCompatible( l->params, r->params );
325        } else if ( type_index(typeid(ast::UnionInstType)) == lid ) {
326                ast::UnionInstType const * l = as<ast::UnionInstType>(lhs);
327                ast::UnionInstType const * r = as<ast::UnionInstType>(rhs);
328
329                if ( l->name != r->name ) return false;
330                return paramListsPolyCompatible( l->params, r->params );
331        } else if ( type_index(typeid(ast::EnumInstType)) == lid ) {
332                ast::EnumInstType const * l = as<ast::EnumInstType>(lhs);
333                ast::EnumInstType const * r = as<ast::EnumInstType>(rhs);
334
335                return l->name == r->name;
336        } else if ( type_index(typeid(ast::TraitInstType)) == lid ) {
337                ast::TraitInstType const * l = as<ast::TraitInstType>(lhs);
338                ast::TraitInstType const * r = as<ast::TraitInstType>(rhs);
339
340                return l->name == r->name;
341        } else if ( type_index(typeid(ast::TupleType)) == lid ) {
342                ast::TupleType const * l = as<ast::TupleType>(lhs);
343                ast::TupleType const * r = as<ast::TupleType>(rhs);
344
345                std::vector<ast::ptr<ast::Type>> ltypes, rtypes;
346                flattenList( l->types, ( ltypes ) );
347                flattenList( r->types, ( rtypes ) );
348                if ( ltypes.size() != rtypes.size() ) return false;
349
350                for ( unsigned i = 0 ; i < ltypes.size() ; ++i ) {
351                        if ( !typesPolyCompatible( ltypes[i], rtypes[i] ) ) return false;
352                }
353                return true;
354        // The remaining types (VoidType, VarArgsType, ZeroType & OneType)
355        // have no variation so will always be equal.
356        } else {
357                return true;
358        }
359}
360
361bool needsBoxing( const ast::Type * param, const ast::Type * arg,
362                const TypeVarMap & typeVars, const ast::TypeSubstitution * subst ) {
363        // Don't need to box if the parameter is not polymorphic.
364        if ( !isPolyType( param, typeVars ) ) return false;
365
366        ast::ptr<ast::Type> newType = arg;
367        if ( subst ) {
368                int count = subst->apply( newType );
369                (void)count;
370        }
371        // Only need to box if the argument is not also polymorphic.
372        return !isPolyType( newType );
373}
374
375bool needsBoxing(
376                const ast::Type * param, const ast::Type * arg,
377                const ast::ApplicationExpr * expr,
378                const ast::TypeSubstitution * subst ) {
379        const ast::FunctionType * function = getFunctionType( expr->func->result );
380        assertf( function, "ApplicationExpr has non-function type: %s", toString( expr->func->result ).c_str() );
381        TypeVarMap exprTyVars;
382        makeTypeVarMap( function, exprTyVars );
383        return needsBoxing( param, arg, exprTyVars, subst );
384}
385
386void addToTypeVarMap( const ast::TypeDecl * decl, TypeVarMap & typeVars ) {
387        typeVars.insert( ast::TypeEnvKey( decl, 0, 0 ), ast::TypeData( decl ) );
388}
389
390void addToTypeVarMap( const ast::TypeInstType * type, TypeVarMap & typeVars ) {
391        typeVars.insert( ast::TypeEnvKey( *type ), ast::TypeData( type->base ) );
392}
393
394void makeTypeVarMap( const ast::Type * type, TypeVarMap & typeVars ) {
395        if ( auto func = dynamic_cast<ast::FunctionType const *>( type ) ) {
396                for ( auto & typeVar : func->forall ) {
397                        assert( typeVar );
398                        addToTypeVarMap( typeVar, typeVars );
399                }
400        }
401        if ( auto pointer = dynamic_cast<ast::PointerType const *>( type ) ) {
402                makeTypeVarMap( pointer->base, typeVars );
403        }
404}
405
406void makeTypeVarMap( const ast::FunctionDecl * decl, TypeVarMap & typeVars ) {
407        for ( auto & typeDecl : decl->type_params ) {
408                addToTypeVarMap( typeDecl, typeVars );
409        }
410}
411
412} // namespace GenPoly
413
414// Local Variables: //
415// tab-width: 4 //
416// mode: c++ //
417// compile-command: "make install" //
418// End: //
Note: See TracBrowser for help on using the repository browser.