source: src/ResolvExpr/SpecCost.cc@ a2eb21a

Last change on this file since a2eb21a was 978e5eb, checked in by Michael Brooks <mlbrooks@…>, 5 years ago

Calculation of specialization benefit (spec "cost") looks inside the body of xInstTypes, counting one benefit for each inner type constructor met. Fixes 225

  • Property mode set to 100644
File size: 7.2 KB
RevLine 
[1dd1bd2]1//
2// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
3//
4// The contents of this file are covered under the licence agreement in the
5// file "LICENCE" distributed with Cforall.
6//
7// SpecCost.cc --
8//
9// Author : Aaron B. Moss
10// Created On : Tue Oct 02 15:50:00 2018
[5aa4656]11// Last Modified By : Andrew Beach
[03bf5c8]12// Last Modified On : Wed Jul 3 11:07:00 2019
13// Update Count : 3
[1dd1bd2]14//
15
[03bf5c8]16#include <cassert>
[1dd1bd2]17#include <limits>
18#include <list>
[5aa4656]19#include <type_traits>
[1dd1bd2]20
[5aa4656]21#include "AST/Pass.hpp"
[9d5089e]22#include "AST/Type.hpp"
[1dd1bd2]23#include "Common/PassVisitor.h"
24#include "SynTree/Declaration.h"
25#include "SynTree/Expression.h"
26#include "SynTree/Type.h"
27
28namespace ResolvExpr {
29
30 /// Counts specializations in a type
31 class CountSpecs : public WithShortCircuiting, public WithVisitorRef<CountSpecs> {
32 int count = -1; ///< specialization count (-1 for none)
33
34 public:
35 int get_count() const { return count >= 0 ? count : 0; }
36
37 // mark specialization of base type
38 void postvisit(PointerType*) { if ( count >= 0 ) ++count; }
39
40 // mark specialization of base type
41 void postvisit(ArrayType*) { if ( count >= 0 ) ++count; }
42
43 // mark specialization of base type
44 void postvisit(ReferenceType*) { if ( count >= 0 ) ++count; }
45
[978e5eb]46 void postvisit(StructInstType*) { if ( count >= 0 ) ++count; }
47 void postvisit(UnionInstType*) { if ( count >= 0 ) ++count; }
48
[1dd1bd2]49 private:
50 // takes minimum non-negative count over parameter/return list
51 void takeminover( int& mincount, std::list<DeclarationWithType*>& dwts ) {
52 for ( DeclarationWithType* dwt : dwts ) {
53 count = -1;
54 maybeAccept( dwt->get_type(), *visitor );
55 if ( count != -1 && count < mincount ) mincount = count;
56 }
57 }
58
59 public:
60 // take minimal specialization value over ->returnVals and ->parameters
61 void previsit(FunctionType* fty) {
62 int mincount = std::numeric_limits<int>::max();
63 takeminover( mincount, fty->parameters );
64 takeminover( mincount, fty->returnVals );
65 // add another level to mincount if set
66 count = mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
67 // already visited children
68 visit_children = false;
69 }
[5aa4656]70
[1dd1bd2]71 private:
72 // returns minimum non-negative count + 1 over type parameters (-1 if none such)
73 int minover( std::list<Expression*>& parms ) {
74 int mincount = std::numeric_limits<int>::max();
75 for ( Expression* parm : parms ) {
76 count = -1;
77 maybeAccept( parm->result, *visitor );
78 if ( count != -1 && count < mincount ) mincount = count;
79 }
80 return mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
81 }
82
83 public:
84 // look for polymorphic parameters
85 void previsit(StructInstType* sty) {
86 count = minover( sty->parameters );
87 }
[5aa4656]88
[1dd1bd2]89 // look for polymorphic parameters
90 void previsit(UnionInstType* uty) {
91 count = minover( uty->parameters );
92 }
93
94 // note polymorphic type (which may be specialized)
95 // xxx - maybe account for open/closed type variables
96 void postvisit(TypeInstType*) { count = 0; }
97
98 // take minimal specialization over elements
99 // xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
100 void previsit(TupleType* tty) {
101 int mincount = std::numeric_limits<int>::max();
102 for ( Type* ty : tty->types ) {
103 count = -1;
104 maybeAccept( ty, *visitor );
105 if ( count != -1 && count < mincount ) mincount = count;
106 }
107 count = mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
108 visit_children = false;
109 }
110 };
111
112 /// Returns the (negated) specialization cost for a given type
113 int specCost( Type* ty ) {
114 PassVisitor<CountSpecs> counter;
115 maybeAccept( ty, *counter.pass.visitor );
116 return counter.pass.get_count();
117 }
[9d5089e]118
[5aa4656]119namespace {
120 /// The specialization counter inner class.
121 class SpecCounter : public ast::WithShortCircuiting, public ast::WithVisitorRef<SpecCounter> {
122 int count = -1; ///< specialization count (-1 for none)
123
124 // Converts the max value to -1 (none), otherwise increments the value.
125 static int toNoneOrInc( int value ) {
126 assert( 0 <= value );
127 return value < std::numeric_limits<int>::max() ? value + 1 : -1;
128 }
129
130 template<typename T> using MapperT =
131 typename std::add_pointer<ast::Type const *(typename T::value_type const &)>::type;
132
[03bf5c8]133 #warning Should use a standard maybe_accept
134 void maybe_accept( ast::Type const * type ) {
135 if ( type ) {
136 auto node = type->accept( *visitor );
137 assert( node == nullptr || node == type );
138 }
139 }
140
[5aa4656]141 // Update the minimum to the new lowest non-none value.
142 template<typename T>
143 void updateMinimumPresent( int & minimum, const T & list, MapperT<T> mapper ) {
144 for ( const auto & node : list ) {
145 count = -1;
[03bf5c8]146 maybe_accept( mapper( node ) );
[5aa4656]147 if ( count != -1 && count < minimum ) minimum = count;
148 }
149 }
150
151 // Returns minimum non-negative count + 1 over type parameters (-1 if none such).
152 template<typename T>
153 int minimumPresent( const T & list, MapperT<T> mapper ) {
154 int minCount = std::numeric_limits<int>::max();
155 updateMinimumPresent( minCount, list, mapper );
156 return toNoneOrInc( minCount );
157 }
158
159 // The three mappers:
160 static const ast::Type * decl_type( const ast::ptr< ast::DeclWithType > & decl ) {
161 return decl->get_type();
162 }
163 static const ast::Type * expr_result( const ast::ptr< ast::Expr > & expr ) {
164 return expr->result;
165 }
166 static const ast::Type * type_deref( const ast::ptr< ast::Type > & type ) {
167 return type.get();
168 }
169
170 public:
171 int get_count() const { return 0 <= count ? count : 0; }
172
173 // Mark specialization of base type.
174 void postvisit( const ast::PointerType * ) { if ( count >= 0 ) ++count; }
175 void postvisit( const ast::ArrayType * ) { if ( count >= 0 ) ++count; }
176 void postvisit( const ast::ReferenceType * ) { if ( count >= 0 ) ++count; }
177
[978e5eb]178 void postvisit( const ast::StructInstType * ) { if ( count >= 0 ) ++count; }
179 void postvisit( const ast::UnionInstType * ) { if ( count >= 0 ) ++count; }
180
[5aa4656]181 // Use the minimal specialization value over returns and params.
182 void previsit( const ast::FunctionType * fty ) {
183 int minCount = std::numeric_limits<int>::max();
[954c954]184 updateMinimumPresent( minCount, fty->params, type_deref );
185 updateMinimumPresent( minCount, fty->returns, type_deref );
[5aa4656]186 // Add another level to minCount if set.
187 count = toNoneOrInc( minCount );
188 // We have already visited children.
189 visit_children = false;
190 }
191
192 // Look for polymorphic parameters.
193 void previsit( const ast::StructInstType * sty ) {
194 count = minimumPresent( sty->params, expr_result );
195 }
196
197 // Look for polymorphic parameters.
198 void previsit( const ast::UnionInstType * uty ) {
199 count = minimumPresent( uty->params, expr_result );
200 }
201
202 // Note polymorphic type (which may be specialized).
203 // xxx - maybe account for open/closed type variables
204 void postvisit( const ast::TypeInstType * ) { count = 0; }
205
206 // Use the minimal specialization over elements.
207 // xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
208 void previsit( const ast::TupleType * tty ) {
209 count = minimumPresent( tty->types, type_deref );
210 visit_children = false;
211 }
212 };
213
214} // namespace
215
216int specCost( const ast::Type * type ) {
217 if ( nullptr == type ) {
[9d5089e]218 return 0;
219 }
[5aa4656]220 ast::Pass<SpecCounter> counter;
[7ff3e522]221 type->accept( counter );
222 return counter.core.get_count();
[5aa4656]223}
224
[1dd1bd2]225} // namespace ResolvExpr
226
227// Local Variables: //
228// tab-width: 4 //
229// mode: c++ //
230// compile-command: "make install" //
231// End: //
Note: See TracBrowser for help on using the repository browser.