//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// TypeEnvironment.h --
//
// Author           : Aaron B. Moss
// Created On       : Sun May 17 12:24:58 2015
// Last Modified By : Aaron B. Moss
// Last Modified On : Fri Jun 29 16:00:00 2018
// Update Count     : 5
//

#pragma once

#include <functional>                      // for less
#include <iostream>                        // for ostream
#include <iterator>
#include <list>                            // for list, list<>::iterator, list<>...
#include <map>                             // for map, map<>::value_compare
#include <set>                             // for set
#include <string>                          // for string
#include <utility>                         // for pair
#include <vector>                          // for vector

#include "WidenMode.h"                 // for WidenMode

#include "Common/InternedString.h"         // for interned_string
#include "Common/PersistentDisjointSet.h"  // for PersistentDisjointSet
#include "Common/PersistentMap.h"          // for PersistentMap
#include "SynTree/Declaration.h"           // for TypeDecl::Data, DeclarationWit...
#include "SynTree/SynTree.h"               // for UniqueId
#include "SynTree/Type.h"                  // for Type, TypeInstType, Type::ForallList
#include "SynTree/TypeSubstitution.h"      // for TypeSubstitution

template< typename Pass > class PassVisitor;
class GcTracer;
namespace SymTab { class Indexer; }

namespace ResolvExpr {
	// adding this comparison operator significantly improves assertion resolution run time for
	// some cases. The current resolution algorithm's speed partially depends on the order of
	// assertions. Assertions which have fewer possible matches should appear before
	// assertions which have more possible matches. This seems to imply that this could
	// be further improved by providing an indexer as an additional argument and ordering based
	// on the number of matches of the same kind (object, function) for the names of the
	// declarations.
	//
	// I've seen a TU go from 54 minutes to 1 minute 34 seconds with the addition of this 
	// comparator.
	//
	// Note: since this compares pointers for position, minor changes in the source file that affect
	// memory layout can alter compilation time in unpredictable ways. For example, the placement
	// of a line directive can reorder type pointers with respect to each other so that assertions
	// are seen in different orders, causing a potentially different number of unification calls 
	// when resolving assertions. I've seen a TU go from 36 seconds to 27 seconds by reordering 
	// line directives alone, so it would be nice to fix this comparison so that assertions compare 
	// more consistently. I've tried to modify this to compare on mangle name instead of type as 
	// the second comparator, but this causes some assertions to never be recorded. More 
	// investigation is needed.
	struct AssertCompare {
		bool operator()( DeclarationWithType * d1, DeclarationWithType * d2 ) const {
			int cmp = d1->get_name().compare( d2->get_name() );
			return cmp < 0 ||
				( cmp == 0 && d1->get_type() < d2->get_type() );
		}
	};
	/// Per-assertion data stored as the mapped value of an AssertionSet.
	struct AssertionSetValue {
		// Usage flag for the assertion. NOTE(review): has no default member initializer, so a
		// default-initialized instance leaves it indeterminate; value-initialize (e.g. `{}`) to
		// get false.
		bool isUsed;
		// Chain of unique IDs of the assertion declarations. The first ID in the chain is the ID
		// of an assertion on the current type, with each successive ID being the ID of an
		// assertion pulled in by the previous ID. The last ID in the chain is the ID of the
		// assertion that pulled in the current assertion.
		std::list< UniqueId > idChain;
	};
	typedef std::map< DeclarationWithType*, AssertionSetValue, AssertCompare > AssertionSet;
	typedef std::map< std::string, TypeDecl::Data > OpenVarSet;

	void printAssertionSet( const AssertionSet &, std::ostream &, int indent = 0 );
	void printOpenVarSet( const OpenVarSet &, std::ostream &, int indent = 0 );

	/// A data structure for holding all the necessary information for a type binding.
	///
	/// Copying clones the bound type with maybeClone (rather than sharing the pointer); moving
	/// transfers the pointer as-is.
	struct BoundType {
		Type* type = nullptr;        ///< Bound type; may be null (no binding)
		bool allowWidening = false;  ///< Whether this binding may be widened
		TypeDecl::Data data;         ///< Declaration data associated with the binding

		/// Empty binding: null type, widening disallowed. The default member initializers above
		/// match what value-initialization (`BoundType{}`) already produced, and keep plain
		/// default-initialization (`BoundType b;`) from leaving these members indeterminate.
		BoundType() = default;
		/// Unbound entry for a type variable declaration; widening allowed.
		BoundType( TypeDecl* td ) : type{nullptr}, allowWidening{true}, data{td} {}
		BoundType( Type* ty, bool aw, const TypeDecl::Data& td )
			: type{ty}, allowWidening{aw}, data{td} {}
		BoundType( const BoundType& o )
			: type{maybeClone(o.type)}, allowWidening{o.allowWidening}, data{o.data} {}
		BoundType( BoundType&& o ) = default;
		BoundType& operator= (const BoundType& o) {
			if ( this == &o ) return *this;
			type = maybeClone( o.type );
			allowWidening = o.allowWidening;
			data = o.data;
			return *this;
		}
		BoundType& operator= (BoundType&& o) = default;
	};

	class TypeEnvironment;

	/// Handle to an equivalence class of a TypeEnvironment, from which the class can be
	/// reconstituted on demand.
	class ClassRef {
		friend TypeEnvironment;

		const TypeEnvironment* env;  ///< Environment holding the class; null for an empty reference
		interned_string root;        ///< Name of the class's root type variable

	public:
		/// Builds an empty reference, which converts to false.
		ClassRef() : env(nullptr), root(nullptr) {}
		/// Builds a reference to the class rooted at `root` within `env`.
		ClassRef( const TypeEnvironment* env, interned_string root ) : env(env), root(root) {}

		/// Root variable recorded in this reference (as constructed or last updated).
		interned_string get_root() const { return root; }

		/// Re-synchronizes the stored root with the class's current representative and returns
		/// it; undefined behaviour on an empty reference.
		inline interned_string update_root();

		/// Collects the type variables of the referenced class into a T (empty if there are none).
		template<typename T = std::vector<interned_string>>
		inline T get_vars() const;

		/// Binding information of the referenced class, or a default BoundType if there is none.
		inline BoundType get_bound() const;

		/// True iff this reference points into some environment.
		explicit operator bool() const { return env != nullptr; }

		bool operator== (const ClassRef& o) const { return root == o.root && env == o.env; }
		bool operator!= (const ClassRef& o) const { return env != o.env || !(root == o.root); }
	};

	class ValidateGuard;

	/// Environment of type-variable equivalence classes and their bindings.
	///
	/// The classes and bindings are held through pointers to persistent data structures, so the
	/// implicitly generated copy operations (used by clone()) copy only the pointers.
	/// NOTE(review): nothing here frees them; lifetime is presumably managed by the garbage
	/// collector (GcTracer is a friend) — confirm.
	class TypeEnvironment {
		friend ClassRef;
		friend GcTracer;

		/// Backing storage for equivalence classes
		using Classes = PersistentDisjointSet<interned_string>;
		/// Type bindings included in this environment (from class root)
		using Bindings = PersistentMap<interned_string, BoundType>;

		/// Sets of equivalent type variables, stored by name
		Classes* classes;
		/// Bindings from roots of equivalence classes to type binding information.
		/// All roots have a binding so that the list of classes can be reconstituted, though these
		/// may be null.
		Bindings* bindings;

		// for debugging
		friend ValidateGuard;
		const char* last_fn = "<none>";

		/// Merges the classes rooted at root1 and root2, returning a pair containing the root and
		/// child of the bound class. Does not check for validity of merge.
		std::pair<interned_string, interned_string> mergeClasses(
			interned_string root1, interned_string root2 );

	public:
		/// Forward iterator over the equivalence classes of an environment, one per binding root.
		/// Dereferencing yields a ClassRef by value.
		class iterator {
			friend TypeEnvironment;

			const TypeEnvironment* env = nullptr;
			Bindings::iterator it;

			iterator(const TypeEnvironment* e, Bindings::iterator&& i) : env(e), it(std::move(i)) {}

			ClassRef ref() const { return { env, it->first }; }
		public:
			/// Result of operator->: holds the ClassRef so that `it->member` is well-formed even
			/// though dereferencing produces a temporary rather than a reference into storage.
			struct arrow_proxy {
				ClassRef value;
				const ClassRef* operator-> () const { return &value; }
			};

			// Member types spelled out directly; deriving from std::iterator is deprecated in C++17.
			using iterator_category = std::forward_iterator_tag;
			using value_type = ClassRef;
			using difference_type = std::iterator_traits<Bindings::iterator>::difference_type;
			using pointer = arrow_proxy;
			using reference = ClassRef;

			iterator() = default;

			reference operator* () const { return ref(); }
			pointer operator-> () const { return { ref() }; }

			iterator& operator++ () { ++it; return *this; }
			iterator operator++ (int) { iterator tmp = *this; ++(*this); return tmp; }

			bool operator== (const iterator& o) const { return env == o.env && it == o.it; }
			bool operator!= (const iterator& o) const { return !(*this == o); }
		};

		/// Finds a reference to the class containing `var`, invalid if none such.
		/// returned root variable will be valid regardless
		ClassRef lookup( interned_string var ) const;

		/// Binds a type variable to a type; returns false if fails
		bool bindVar( TypeInstType* typeInst, Type* bindTo, const TypeDecl::Data& data,
			AssertionSet& need, AssertionSet& have, const OpenVarSet& openVars,
			WidenMode widenMode, const SymTab::Indexer& indexer );

		/// Binds two type variables together; returns false if fails
		bool bindVarToVar( TypeInstType* var1, TypeInstType* var2, const TypeDecl::Data& data,
			AssertionSet& need, AssertionSet& have, const OpenVarSet& openVars,
			WidenMode widenMode, const SymTab::Indexer& indexer );

	public:
		TypeEnvironment() : classes{ new Classes{} }, bindings{ new Bindings{} } {}

		void add( const Type::ForallList &tyDecls );
		void add( const TypeSubstitution & sub );
		template< typename SynTreeClass > int apply( SynTreeClass *&type ) const;
		template< typename SynTreeClass > int applyFree( SynTreeClass *&type ) const;
		void makeSubstitution( TypeSubstitution &result ) const;
		bool isEmpty() const { return classes->empty(); }
		void print( std::ostream &os, Indenter indent = {} ) const;

		/// Combines two environments without checking invariants.
		/// Caller should ensure environments do not share type variables.
		void simpleCombine( const TypeEnvironment &second );

		/// Combines two environments, checking compatibility. Both environments must be versioned
		/// from the same initial environment.
		/// Returns false if unsuccessful, but does NOT roll back partial changes
		bool combine( const TypeEnvironment& o, const SymTab::Indexer& indexer );

		void extractOpenVars( OpenVarSet &openVars ) const;
		TypeEnvironment *clone() const { return new TypeEnvironment( *this ); }

		/// Iteratively adds the environment of a new actual (with allowWidening = false),
		/// and extracts open variables.
		void addActual( const TypeEnvironment& actualEnv, OpenVarSet& openVars );

		/// Disallows widening for all bindings in the environment
		void forbidWidening();

		iterator begin() const { return { this, bindings->begin() }; }
		iterator end() const { return { this, bindings->end() }; }
	};

	interned_string ClassRef::update_root() {
		// ask the disjoint set for the class's current representative and remember it
		root = env->classes->find( root );
		return root;
	}

	template<typename T>
	T ClassRef::get_vars() const {
		T collected;
		// append each class member in the order for_class visits them
		env->classes->for_class( root, [&collected]( interned_string member ) {
			collected.insert( collected.end(), member );
		} );
		return collected;
	}

	BoundType ClassRef::get_bound() const {
		auto* table = env->bindings;
		// roots without an entry fall back to a value-initialized BoundType
		return table->get_or_default( root, BoundType() );
	}

	template< typename SynTreeClass >
	int TypeEnvironment::apply( SynTreeClass *&type ) const {
		TypeSubstitution sub;
		makeSubstitution( sub );
		return sub.apply( type );
	}

	template< typename SynTreeClass >
	int TypeEnvironment::applyFree( SynTreeClass *&type ) const {
		TypeSubstitution sub;
		makeSubstitution( sub );
		return sub.applyFree( type );
	}

	/// Stream output for a TypeEnvironment. NOTE(review): presumably forwards to
	/// TypeEnvironment::print — confirm in the implementation file.
	std::ostream & operator<<( std::ostream & out, const TypeEnvironment & env );

	/// GC tracing hook for a TypeEnvironment. NOTE(review): presumably visits the nodes reachable
	/// from env's bindings — confirm in the implementation file.
	PassVisitor<GcTracer> & operator<<( PassVisitor<GcTracer> & gc, const TypeEnvironment & env );
} // namespace ResolvExpr

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
