source: src/ResolvExpr/SpecCost.cc @ 809e058

ADTast-experimental
Last change on this file since 809e058 was 978e5eb, checked in by Michael Brooks <mlbrooks@…>, 4 years ago

Calculation of specialization benefit (spec "cost") looks inside the body of xInstTypes, counting one benefit for each inner type constructor met. Fixes 225?

  • Property mode set to 100644
File size: 7.2 KB
Line 
1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecCost.cc --
8//
9// Author           : Aaron B. Moss
10// Created On       : Tue Oct 02 15:50:00 2018
11// Last Modified By : Andrew Beach
12// Last Modified On : Wed Jul  3 11:07:00 2019
13// Update Count     : 3
14//
15
16#include <cassert>
17#include <limits>
18#include <list>
19#include <type_traits>
20
21#include "AST/Pass.hpp"
22#include "AST/Type.hpp"
23#include "Common/PassVisitor.h"
24#include "SynTree/Declaration.h"
25#include "SynTree/Expression.h"
26#include "SynTree/Type.h"
27
28namespace ResolvExpr {
29
30        /// Counts specializations in a type
31        class CountSpecs : public WithShortCircuiting, public WithVisitorRef<CountSpecs> {
32                int count = -1;  ///< specialization count (-1 for none)
33
34        public:
35                int get_count() const { return count >= 0 ? count : 0; }
36
37                // mark specialization of base type
38                void postvisit(PointerType*) { if ( count >= 0 ) ++count; }
39
40                // mark specialization of base type
41                void postvisit(ArrayType*) { if ( count >= 0 ) ++count; }
42
43                // mark specialization of base type
44                void postvisit(ReferenceType*) { if ( count >= 0 ) ++count; }
45
46                void postvisit(StructInstType*) { if ( count >= 0 ) ++count; }
47                void postvisit(UnionInstType*) { if ( count >= 0 ) ++count; }
48
49        private:
50                // takes minimum non-negative count over parameter/return list
51                void takeminover( int& mincount, std::list<DeclarationWithType*>& dwts ) {
52                        for ( DeclarationWithType* dwt : dwts ) {
53                                count = -1;
54                                maybeAccept( dwt->get_type(), *visitor );
55                                if ( count != -1 && count < mincount ) mincount = count;
56                        }
57                }
58
59        public:
60                // take minimal specialization value over ->returnVals and ->parameters
61                void previsit(FunctionType* fty) {
62                        int mincount = std::numeric_limits<int>::max();
63                        takeminover( mincount, fty->parameters );
64                        takeminover( mincount, fty->returnVals );
65                        // add another level to mincount if set
66                        count = mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
67                        // already visited children
68                        visit_children = false;
69                }
70
71        private:
72                // returns minimum non-negative count + 1 over type parameters (-1 if none such)
73                int minover( std::list<Expression*>& parms ) {
74                        int mincount = std::numeric_limits<int>::max();
75                        for ( Expression* parm : parms ) {
76                                count = -1;
77                                maybeAccept( parm->result, *visitor );
78                                if ( count != -1 && count < mincount ) mincount = count;
79                        }
80                        return mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
81                }
82
83        public:
84                // look for polymorphic parameters
85                void previsit(StructInstType* sty) {
86                        count = minover( sty->parameters );
87                }
88
89                // look for polymorphic parameters
90                void previsit(UnionInstType* uty) {
91                        count = minover( uty->parameters );
92                }
93
94                // note polymorphic type (which may be specialized)
95                // xxx - maybe account for open/closed type variables
96                void postvisit(TypeInstType*) { count = 0; }
97
98                // take minimal specialization over elements
99                // xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
100                void previsit(TupleType* tty) {
101                        int mincount = std::numeric_limits<int>::max();
102                        for ( Type* ty : tty->types ) {
103                                count = -1;
104                                maybeAccept( ty, *visitor );
105                                if ( count != -1 && count < mincount ) mincount = count;
106                        }
107                        count = mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
108                        visit_children = false;
109                }
110        };
111
112        /// Returns the (negated) specialization cost for a given type
113        int specCost( Type* ty ) {
114                PassVisitor<CountSpecs> counter;
115                maybeAccept( ty, *counter.pass.visitor );
116                return counter.pass.get_count();
117        }
118
119namespace {
120        /// The specialization counter inner class.
121        class SpecCounter : public ast::WithShortCircuiting, public ast::WithVisitorRef<SpecCounter> {
122                int count = -1;  ///< specialization count (-1 for none)
123
124                // Converts the max value to -1 (none), otherwise increments the value.
125                static int toNoneOrInc( int value ) {
126                        assert( 0 <= value );
127                        return value < std::numeric_limits<int>::max() ? value + 1 : -1;
128                }
129
130                template<typename T> using MapperT =
131                        typename std::add_pointer<ast::Type const *(typename T::value_type const &)>::type;
132
133                #warning Should use a standard maybe_accept
134                void maybe_accept( ast::Type const * type ) {
135                        if ( type ) {
136                                auto node = type->accept( *visitor );
137                                assert( node == nullptr || node == type );
138                        }
139                }
140
141                // Update the minimum to the new lowest non-none value.
142                template<typename T>
143                void updateMinimumPresent( int & minimum, const T & list, MapperT<T> mapper ) {
144                        for ( const auto & node : list ) {
145                                count = -1;
146                                maybe_accept( mapper( node ) );
147                                if ( count != -1 && count < minimum ) minimum = count;
148                        }
149                }
150
151                // Returns minimum non-negative count + 1 over type parameters (-1 if none such).
152                template<typename T>
153                int minimumPresent( const T & list, MapperT<T> mapper ) {
154                        int minCount = std::numeric_limits<int>::max();
155                        updateMinimumPresent( minCount, list, mapper );
156                        return toNoneOrInc( minCount );
157                }
158
159                // The three mappers:
160                static const ast::Type * decl_type( const ast::ptr< ast::DeclWithType > & decl ) {
161                        return decl->get_type();
162                }
163                static const ast::Type * expr_result( const ast::ptr< ast::Expr > & expr ) {
164                        return expr->result;
165                }
166                static const ast::Type * type_deref( const ast::ptr< ast::Type > & type ) {
167                        return type.get();
168                }
169
170        public:
171                int get_count() const { return 0 <= count ? count : 0; }
172
173                // Mark specialization of base type.
174                void postvisit( const ast::PointerType * ) { if ( count >= 0 ) ++count; }
175                void postvisit( const ast::ArrayType * ) { if ( count >= 0 ) ++count; }
176                void postvisit( const ast::ReferenceType * ) { if ( count >= 0 ) ++count; }
177
178                void postvisit( const ast::StructInstType * ) { if ( count >= 0 ) ++count; }
179                void postvisit( const ast::UnionInstType * ) { if ( count >= 0 ) ++count; }
180
181                // Use the minimal specialization value over returns and params.
182                void previsit( const ast::FunctionType * fty ) {
183                        int minCount = std::numeric_limits<int>::max();
184                        updateMinimumPresent( minCount, fty->params, type_deref );
185                        updateMinimumPresent( minCount, fty->returns, type_deref );
186                        // Add another level to minCount if set.
187                        count = toNoneOrInc( minCount );
188                        // We have already visited children.
189                        visit_children = false;
190                }
191
192                // Look for polymorphic parameters.
193                void previsit( const ast::StructInstType * sty ) {
194                        count = minimumPresent( sty->params, expr_result );
195                }
196
197                // Look for polymorphic parameters.
198                void previsit( const ast::UnionInstType * uty ) {
199                        count = minimumPresent( uty->params, expr_result );
200                }
201
202                // Note polymorphic type (which may be specialized).
203                // xxx - maybe account for open/closed type variables
204                void postvisit( const ast::TypeInstType * ) { count = 0; }
205
206                // Use the minimal specialization over elements.
207                // xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
208                void previsit( const ast::TupleType * tty ) {
209                        count = minimumPresent( tty->types, type_deref );
210                        visit_children = false;
211                }
212        };
213
214} // namespace
215
216int specCost( const ast::Type * type ) {
217        if ( nullptr == type ) {
218                return 0;
219        }
220        ast::Pass<SpecCounter> counter;
221        type->accept( counter );
222        return counter.core.get_count();
223}
224
225} // namespace ResolvExpr
226
227// Local Variables: //
228// tab-width: 4 //
229// mode: c++ //
230// compile-command: "make install" //
231// End: //
Note: See TracBrowser for help on using the repository browser.