source: src/GenPoly/FindFunction.cc @ 3bf9d10

Last change on this file since 3bf9d10 was c97b448, checked in by Andrew Beach <ajbeach@…>, 21 months ago

Added some box pass utilities that I believe are working and I don't want to look at all the time.

  • Property mode set to 100644
File size: 6.1 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// FindFunction.cc --
8//
9// Author           : Richard C. Bilson
10// Created On       : Mon May 18 07:44:20 2015
11// Last Modified By : Andrew Beach
12// Last Modified On : Fri Oct  7 17:05:20 2022
13// Update Count     : 7
14//
15
16#include "FindFunction.h"
17
18#include <utility>                      // for pair
19
20#include "AST/Pass.hpp"                 // for Pass
21#include "AST/Type.hpp"
22#include "Common/PassVisitor.h"         // for PassVisitor
23#include "GenPoly/ErasableScopedMap.h"  // for ErasableScopedMap<>::iterator
24#include "GenPoly/GenPoly.h"            // for TyVarMap
25#include "ScrubTyVars.h"                // for ScrubTyVars
26#include "SynTree/Declaration.h"        // for DeclarationWithType, TypeDecl
27#include "SynTree/Mutator.h"            // for Mutator, mutateAll
28#include "SynTree/Type.h"               // for FunctionType, Type, Type::For...
29
30namespace GenPoly {
31        class FindFunction : public WithGuards, public WithVisitorRef<FindFunction>, public WithShortCircuiting {
32          public:
33                FindFunction( std::list< FunctionType const* > &functions, const TyVarMap &tyVars, bool replaceMode, FindFunctionPredicate predicate );
34
35                void premutate( FunctionType * functionType );
36                Type * postmutate( FunctionType * functionType );
37                void premutate( PointerType * pointerType );
38          private:
39                void handleForall( const Type::ForallList &forall );
40
41                std::list< FunctionType const * > & functions;
42                TyVarMap tyVars;
43                bool replaceMode;
44                FindFunctionPredicate predicate;
45        };
46
47        void findFunction( Type *type, std::list< FunctionType const * > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate ) {
48                PassVisitor<FindFunction> finder( functions, tyVars, false, predicate );
49                type->acceptMutator( finder );
50        }
51
52        void findAndReplaceFunction( Type *&type, std::list< FunctionType const * > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate ) {
53                PassVisitor<FindFunction> finder( functions, tyVars, true, predicate );
54                type = type->acceptMutator( finder );
55        }
56
57        FindFunction::FindFunction( std::list< FunctionType const * > &functions, const TyVarMap &tyVars, bool replaceMode, FindFunctionPredicate predicate )
58                : functions( functions ), tyVars( tyVars ), replaceMode( replaceMode ), predicate( predicate ) {
59        }
60
61        void FindFunction::handleForall( const Type::ForallList &forall ) {
62                for ( const Declaration * td : forall ) {
63                        TyVarMap::iterator var = tyVars.find( td->name );
64                        if ( var != tyVars.end() ) {
65                                tyVars.erase( var->first );
66                        } // if
67                } // for
68        }
69
70        void FindFunction::premutate( FunctionType * functionType ) {
71                visit_children = false;
72                GuardScope( tyVars );
73                handleForall( functionType->get_forall() );
74                mutateAll( functionType->get_returnVals(), *visitor );
75        }
76
77        Type * FindFunction::postmutate( FunctionType * functionType ) {
78                Type *ret = functionType;
79                if ( predicate( functionType, tyVars ) ) {
80                        functions.push_back( functionType );
81                        if ( replaceMode ) {
82                                // replace type parameters in function type with void*
83                                ret = ScrubTyVars::scrub( functionType->clone(), tyVars );
84                        } // if
85                } // if
86                return ret;
87        }
88
89        void FindFunction::premutate( PointerType * pointerType ) {
90                GuardScope( tyVars );
91                handleForall( pointerType->get_forall() );
92        }
93
94namespace {
95
96struct FindFunctionCore :
97                public ast::WithGuards,
98                public ast::WithShortCircuiting,
99                public ast::WithVisitorRef<FindFunctionCore> {
100        FindFunctionCore(
101                std::vector<ast::ptr<ast::FunctionType>> & functions,
102                const TypeVarMap & typeVars, FindFunctionPred predicate,
103                bool replaceMode );
104
105        void previsit( ast::FunctionType const * type );
106        ast::Type const * postvisit( ast::FunctionType const * type );
107        void previsit( ast::PointerType const * type );
108private:
109        void handleForall( const ast::FunctionType::ForallList & forall );
110
111        std::vector<ast::ptr<ast::FunctionType>> &functions;
112        TypeVarMap typeVars;
113        FindFunctionPred predicate;
114        bool replaceMode;
115};
116
117FindFunctionCore::FindFunctionCore(
118                std::vector<ast::ptr<ast::FunctionType>> & functions,
119                const TypeVarMap &typeVars, FindFunctionPred predicate,
120                bool replaceMode ) :
121        functions( functions ), typeVars( typeVars ),
122        predicate( predicate ), replaceMode( replaceMode ) {}
123
124void FindFunctionCore::handleForall( const ast::FunctionType::ForallList & forall ) {
125        for ( const ast::ptr<ast::TypeInstType> & td : forall ) {
126                TypeVarMap::iterator var = typeVars.find( *td );
127                if ( var != typeVars.end() ) {
128                        typeVars.erase( var->first );
129                } // if
130        } // for
131}
132
133void FindFunctionCore::previsit( ast::FunctionType const * type ) {
134        visit_children = false;
135        GuardScope( typeVars );
136        handleForall( type->forall );
137        //ast::accept_all( type->returns, *visitor );
138        // This might have to become ast::mutate_each with return.
139        ast::accept_each( type->returns, *visitor );
140}
141
142ast::Type const * FindFunctionCore::postvisit( ast::FunctionType const * type ) {
143        ast::Type const * ret = type;
144        if ( predicate( type, typeVars ) ) {
145                functions.push_back( type );
146                if ( replaceMode ) {
147                        // replace type parameters in function type with void*
148                        ret = scrubTypeVars( ast::deepCopy( type ), typeVars );
149                } // if
150        } // if
151        return ret;
152}
153
154void FindFunctionCore::previsit( ast::PointerType const * /*type*/ ) {
155        GuardScope( typeVars );
156        //handleForall( type->forall );
157}
158
159} // namespace
160
161void findFunction( const ast::Type * type,
162                std::vector<ast::ptr<ast::FunctionType>> & functions,
163                const TypeVarMap & typeVars, FindFunctionPred predicate ) {
164        ast::Pass<FindFunctionCore> pass( functions, typeVars, predicate, false );
165        type->accept( pass );
166        //(void)type;
167        //(void)functions;
168        //(void)typeVars;
169        //(void)predicate;
170}
171
172const ast::Type * findAndReplaceFunction( const ast::Type * type,
173                std::vector<ast::ptr<ast::FunctionType>> & functions,
174                const TypeVarMap & typeVars, FindFunctionPred predicate ) {
175        ast::Pass<FindFunctionCore> pass( functions, typeVars, predicate, true );
176        return type->accept( pass );
177        //(void)functions;
178        //(void)typeVars;
179        //(void)predicate;
180        //return type;
181}
182
183} // namespace GenPoly
184
185// Local Variables: //
186// tab-width: 4 //
187// mode: c++ //
188// compile-command: "make install" //
189// End: //
Note: See TracBrowser for help on using the repository browser.