//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// SpecCost.cc -- Computes the specialization cost of a type for expression resolution
//
// Author           : Aaron B. Moss
// Created On       : Tue Oct 02 15:50:00 2018
// Last Modified By : Aaron B. Moss
// Last Modified On : Tue Oct 02 15:50:00 2018
// Update Count     : 1
//

#include <limits>
#include <list>

#include "Common/PassVisitor.h"
#include "SynTree/Declaration.h"
#include "SynTree/Expression.h"
#include "SynTree/Type.h"

namespace ResolvExpr {

	/// Counts specializations in a type.
	///
	/// The running count is -1 while no type variable (TypeInstType) has been reached.
	/// A bare type variable counts 0, and each pointer, array, or reference wrapped
	/// around a counted type adds one level. Function, struct/union, and tuple types
	/// take the minimum count over their components (ignoring components with no
	/// count) and add one level, or stay at -1 if no component produced a count.
	class CountSpecs : public WithShortCircuiting, public WithVisitorRef<CountSpecs> {
		int count = -1;  ///< specialization count (-1 for none)

	public:
		/// Final specialization count; 0 when no count was recorded
		int get_count() const { return count >= 0 ? count : 0; }

		// mark specialization of base type
		void postvisit(PointerType*) { if ( count >= 0 ) ++count; }

		// mark specialization of base type
		void postvisit(ArrayType*) { if ( count >= 0 ) ++count; }

		// mark specialization of base type
		void postvisit(ReferenceType*) { if ( count >= 0 ) ++count; }

	private:
		/// Counts `ty` from a fresh state and lowers `mincount` to the result if `ty`
		/// produced a count. Every "minimum over components" rule below goes through
		/// here so the reset/visit/compare sequence exists in exactly one place.
		void takemin( int& mincount, Type* ty ) {
			count = -1;
			maybeAccept( ty, *visitor );
			if ( count != -1 && count < mincount ) mincount = count;
		}

		/// Converts a folded minimum into the count of the enclosing type: one level
		/// more than the minimum, or -1 if `mincount` was never lowered from its
		/// std::numeric_limits<int>::max() starting value.
		static int incOrNone( int mincount ) {
			return mincount < std::numeric_limits<int>::max() ? mincount + 1 : -1;
		}

		// takes minimum non-negative count over parameter/return list
		void takeminover( int& mincount, const std::list<DeclarationWithType*>& dwts ) {
			for ( DeclarationWithType* dwt : dwts ) takemin( mincount, dwt->get_type() );
		}

	public:
		// take minimal specialization value over ->returnVals and ->parameters
		void previsit(FunctionType* fty) {
			int mincount = std::numeric_limits<int>::max();
			takeminover( mincount, fty->parameters );
			takeminover( mincount, fty->returnVals );
			// add another level to mincount if set
			count = incOrNone( mincount );
			// already visited children
			visit_children = false;
		}

	private:
		// returns minimum non-negative count + 1 over type parameters (-1 if none such)
		int minover( const std::list<Expression*>& parms ) {
			int mincount = std::numeric_limits<int>::max();
			for ( Expression* parm : parms ) takemin( mincount, parm->result );
			return incOrNone( mincount );
		}

	public:
		// look for polymorphic parameters
		void previsit(StructInstType* sty) {
			count = minover( sty->parameters );
			visit_children = false;
		}

		// look for polymorphic parameters
		void previsit(UnionInstType* uty) {
			count = minover( uty->parameters );
			visit_children = false;
		}

		// note polymorphic type (which may be specialized)
		// xxx - maybe account for open/closed type variables
		void postvisit(TypeInstType*) { count = 0; }

		// take minimal specialization over elements
		// xxx - maybe don't increment, tuple flattening doesn't necessarily specialize
		void previsit(TupleType* tty) {
			int mincount = std::numeric_limits<int>::max();
			for ( Type* ty : tty->types ) takemin( mincount, ty );
			count = incOrNone( mincount );
			visit_children = false;
		}
	};

	/// Returns the (negated) specialization cost for a given type.
	///
	/// The result is the count computed by a CountSpecs pass over `ty`.
	/// It is 0 when `ty` is null or when the pass recorded no count.
	int specCost( Type* ty ) {
		// nothing to visit: a fresh counter would report 0 anyway
		if ( ! ty ) return 0;

		PassVisitor<CountSpecs> specCounter;
		maybeAccept( ty, *specCounter.pass.visitor );
		return specCounter.pass.get_count();
	}
} // namespace ResolvExpr

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
