//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// TypeEnvironment.cc --
//
// Author           : Richard C. Bilson
// Created On       : Sun May 17 12:19:47 2015
// Last Modified By : Peter A. Buhr
// Last Modified On : Sun May 17 12:23:36 2015
// Update Count     : 3
//

#include <cassert>                     // for assert
#include <algorithm>                   // for copy, set_intersection
#include <iterator>                    // for ostream_iterator, insert_iterator
#include <utility>                     // for pair, move

#include "Common/InternedString.h"     // for interned_string
#include "Common/PassVisitor.h"        // for PassVisitor<GcTracer>
#include "Common/utility.h"            // for maybeClone
#include "SymTab/Indexer.h"            // for Indexer
#include "SynTree/GcTracer.h"          // for PassVisitor<GcTracer>
#include "SynTree/Type.h"              // for Type, FunctionType, Type::Fora...
#include "SynTree/TypeSubstitution.h"  // for TypeSubstitution
#include "TypeEnvironment.h"
#include "Unify.h"                     // for unifyInexact

namespace ResolvExpr {
	void printAssertionSet( const AssertionSet &assertions, std::ostream &os, int indent ) {
		for ( AssertionSet::const_iterator i = assertions.begin(); i != assertions.end(); ++i ) {
			i->first->print( os, indent );
			if ( i->second.isUsed ) {
				os << "(used)";
			} else {
				os << "(not used)";
			} // if
		} // for
	}

	void printOpenVarSet( const OpenVarSet &openVars, std::ostream &os, int indent ) {
		os << std::string( indent, ' ' );
		for ( OpenVarSet::const_iterator i = openVars.begin(); i != openVars.end(); ++i ) {
			os << i->first << "(" << i->second << ") ";
		} // for
	}

#if 0
	void EqvClass::initialize( const EqvClass &src, EqvClass &dest ) {
		dest.vars = src.vars;
		dest.type = maybeClone( src.type );
		dest.allowWidening = src.allowWidening;
		dest.data = src.data;
	}

	EqvClass::EqvClass() : vars(), type( 0 ), allowWidening( true ), data() {}

	EqvClass::EqvClass( std::vector<interned_string>&& vs, BoundType&& bound )
		: vars( vs.begin(), vs.end() ), type( maybeClone( bound.type ) ), 
		  allowWidening( bound.allowWidening ), data( bound.data ) {}

	EqvClass::EqvClass( const EqvClass &other ) {
		initialize( other, *this );
	}

	EqvClass &EqvClass::operator=( const EqvClass &other ) {
		if ( this == &other ) return *this;
		initialize( other, *this );
		return *this;
	}

	void EqvClass::print( std::ostream &os, Indenter indent ) const {
		os << "( ";
		std::copy( vars.begin(), vars.end(), std::ostream_iterator< std::string >( os, " " ) );
		os << ")";
		if ( type ) {
			os << " -> ";
			type->print( os, indent+1 );
		} // if
		if ( ! allowWidening ) {
			os << " (no widening)";
		} // if
		os << std::endl;
	}

	const EqvClass* TypeEnvironment::lookup( const std::string &var ) const {
		for ( std::list< EqvClass >::const_iterator i = env.begin(); i != env.end(); ++i ) {
			if ( i->vars.find( var ) != i->vars.end() ) {
///       std::cout << var << " is in class ";
///       i->print( std::cout );
				return &*i;
			}
///     std::cout << var << " is not in class ";
///     i->print( std::cout );
		} // for
		return nullptr;
	}
#endif

	std::pair<interned_string, interned_string> TypeEnvironment::mergeClasses( 
			interned_string root1, interned_string root2 ) {
		// merge classes
		Classes* newClasses = classes->merge( root1, root2 );

		// determine new root
		assertf(classes->get_mode() == Classes::REMFROM, "classes updated to REMFROM by merge");
		auto ret = std::make_pair( classes->get_root(), classes->get_child );
		
		// finalize classes
		classes = newClasses;
		return ret;
	}

	ClassRef TypeEnvironment::lookup( interned_string var ) const {
		interned_string root = classes->find_or_default( var, nullptr );
		if ( root ) return { this, root };
		else return { nullptr, var };
	}

	bool tyVarCompatible( const TypeDecl::Data & data, Type *type ) {
		switch ( data.kind ) {
		  case TypeDecl::Dtype:
			// to bind to an object type variable, the type must not be a function type.
			// if the type variable is specified to be a complete type then the incoming
			// type must also be complete
			// xxx - should this also check that type is not a tuple type and that it's not a ttype?
			return ! isFtype( type ) && (! data.isComplete || type->isComplete() );
		  case TypeDecl::Ftype:
			return isFtype( type );
		  case TypeDecl::Ttype:
			// ttype unifies with any tuple type
			return dynamic_cast< TupleType * >( type ) || Tuples::isTtype( type );
		} // switch
		return false;
	}

	bool TypeEnvironment::bindVar( TypeInstType* typeInst, Type* bindTo, 
			const TypeDecl::Data& data, AssertionSet& need, AssertionSet& have, 
			const OpenVarSet& openVars, WidenMode widenMode, const SymTab::Indexer& indexer ) {
		// remove references from other, so that type variables can only bind to value types
		bindTo = bindTo->stripReferences();
		
		auto tyVar = openVars.find( typeInst->get_name() );
		assert( tyVar != openVars.end() );
		if ( ! tyVarCompatible( tyVar->second, other ) ) return false;

		if ( occurs( bindTo, typeInst->get_name(), *this ) ) return false;

		if ( ClassRef curClass = lookup( typeInst->get_name() ) ) {
			BoundType curData = curClass.get_bound();
			if ( curData.type ) {
				Type* common = nullptr;
				// attempt to unify equivalence class type (which needs its qualifiers restored) 
				// with the target type
				Type* newType = curData.type->clone();
				newType->get_qualifiers() = typeInst->get_qualifiers();
				if ( unifyInexact( newType, bindTo, *this, need, have, openVars, 
						widenMode & WidenMode{ curData.allowWidening, true }, indexer, common ) ) {
					if ( common ) {
						// update bound type to common type
						common->get_qualifiers() = Type::Qualifiers{};
						curData.type = common;
						bindings = bindings->set( curClass.update_root(), curData );
					}
					return true;
				} else return false;
			} else {
				// update bound type to other type
				curData.type = bindTo->clone();
				curData.type->get_qualifiers() = Type::Qualifiers{};
				curData.allowWidening = widenMode.widenFirst && widenMode.widenSecond;
				bindings = bindings->set( curClass.get_root(), curData );
			}
		} else {
			// make new class consisting entirely of this variable
			BoundType curData{ bindTo->clone(), widenMode.first && widenMode.second, data };
			curData.type->get_qualifiers() = Type::Qualifiers{};
			classes = classes->add( curClass.get_root() );
			bindings = bindings->set( curClass.get_root(), curData );
		}
		return true;
	}
	
	bool TypeEnvironment::bindVarToVar( TypeInstType* var1, TypeInstType* var2, 
			const TypeDecl::Data& data, AssertionSet& need, AssertionSet& have, 
			const OpenVarSet& openVars, WidenMode widenMode, const SymTab::Indexer& indexer ) {
		ClassRef class1 = env.lookup( var1->get_name() );
		ClassRef class2 = env.lookup( var2->get_name() );
		
		// exit early if variables already bound together
		if ( class1 && class2 && class1 == class2 ) {
			BoundType data1 = class1.get_bound();
			// narrow the binding if needed
			if ( data1.allowWidening && widenMode.first != widenMode.second ) {
				data1.allowWidening = false;
				bindings = bindings->set( class1.get_root(), data1 );
			}
			return true;
		}

		BoundType data1 = class1 ? class1.get_bound() : BoundType{};
		BoundType data2 = class2 ? class2.get_bound() : BoundType{};

		bool widen1 = data1.allowWidening && widenMode.widenFirst;
		bool widen2 = data2.allowWidening && widenMode.widenSecond;

		if ( data1.type && data2.type ) {
			// attempt to unify bound classes
			Type* common = nullptr;
			if ( unifyInexact( data1.type->clone(), data2.type->clone(), *this, need, have, 
					openVars, WidenMode{ widen1, widen2 }, indexer, common ) ) {
				// merge type variables
				interned_string root = mergeClasses( class1.update_root(), class2.update_root() );
				// update bindings
				data1.allowWidening = widen1 && widen2;
				if ( common ) {
					common->get_qualifiers() = Type::Qualifiers{};
					data1.type = common;
				}
				bindings = bindings->set( root, data1 );
			} else return false;
		} else if ( class1 && class2 ) {
			// both classes exist, only one bound -- merge type variables
			auto merged = mergeClasses( class1.get_root(), class2.get_root() );
			// remove subordinate binding
			bindings = bindings->erase( merged.second );
			// update root binding (narrow widening as needed, or set new binding for changed root)
			if ( data1.type ) {
				if ( data1.allowWidening != widen1 ) {
					data1.allowWidening = widen1;
					bindings = bindings->set( merged.first, data1 );
				} else if ( merged.first == class2.get_root() ) {
					bindings = bindings->set( merged.first, data1 );
				}
			} else /* if ( data2.type ) */ {
				if ( data2.allowWidening != widen2 ) {
					data2.allowWidening = widen2;
					bindings = bindings->set( root, data2 );
				} else if ( merged.first == class1.get_root() ) {
					bindings = bindings->set( merged.first, data2 );
				}
			}
		} else if ( class1 ) {
			// add unbound var2 to class1
			classes = classes->add( class2.get_root() );
			auto merged = mergeClasses( class1.get_root(), class2.get_root() );
			// update bindings (narrow as needed, or switch binding to root)
			if ( merged.first == class2.get_root() ) {
				data1.allowWidening = widen1;
				bindings = bindings->erase( merged.second )->set( merged.first, data1 );
			} else if ( data1.allowWidening != widen1 ) {
				bindings = bindings->set( merged.first, data1 );
			}
		} else if ( class2 ) {
			// add unbound var1 to class1
			classes = classes->add( class1.get_root() );
			auto merged = mergeClasses( class1.get_root(), class2.get_root() );
			// update bindings (narrow as needed, or switch binding to root)
			if ( merged.first == class1.get_root() ) {
				data2.allowWidening = widen2;
				bindings = bindings->erase( merged.second )->set( merged.first, data2 );
			} else if ( data2.allowWidening != widen2 ) {
				bindings = bindings->set( merged.first, data2 );
			}
		} else {
			// make new class with pair of unbound vars
			classes = classes->add( class1.get_root() )->add( class2.get_root() );
			auto merged = mergeClasses( class1.get_root(), class2.get_root() );
			bindings = bindings->set( merged.first, BoundType{ nullptr, widen1 && widen2, data } );
		}
		return true;
	}

#if !1
	/// Removes any class from env that intersects eqvClass
	void filterOverlappingClasses( std::list<EqvClass> &env, const EqvClass &eqvClass ) {
		for ( auto i = env.begin(); i != env.end(); ) {
			auto next = i;
			++next;
			std::set<std::string> intersection;
			std::set_intersection( i->vars.begin(), i->vars.end(), eqvClass.vars.begin(), eqvClass.vars.end(), 
				std::inserter( intersection, intersection.begin() ) );
			if ( ! intersection.empty() ) { env.erase( i ); }
			i = next;
		}
	}

	void TypeEnvironment::add( const EqvClass &eqvClass ) {
		filterOverlappingClasses( env, eqvClass );
		env.push_back( eqvClass );
	}

	void TypeEnvironment::add( EqvClass &&eqvClass ) {
		filterOverlappingClasses( env, eqvClass );
		env.push_back( std::move(eqvClass) );
	}

	void TypeEnvironment::add( const Type::ForallList &tyDecls ) {
		for ( Type::ForallList::const_iterator i = tyDecls.begin(); i != tyDecls.end(); ++i ) {
			EqvClass newClass;
			newClass.vars.insert( (*i)->get_name() );
			newClass.data = TypeDecl::Data{ (*i) };
			env.push_back( newClass );
		} // for
	}

	void TypeEnvironment::add( const TypeSubstitution & sub ) {
		EqvClass newClass;
		for ( auto p : sub ) {
			newClass.vars.insert( p.first );
			newClass.type = p.second->clone();
			newClass.allowWidening = false;
			// Minimal assumptions. Not technically correct, but might be good enough, and
			// is the best we can do at the moment since information is lost in the
			// transition to TypeSubstitution
			newClass.data = TypeDecl::Data{ TypeDecl::Dtype, false };
			add( newClass );
		}
	}

	void TypeEnvironment::makeSubstitution( TypeSubstitution &sub ) const {
		for ( std::list< EqvClass >::const_iterator theClass = env.begin(); theClass != env.end(); ++theClass ) {
			for ( std::set< std::string >::const_iterator theVar = theClass->vars.begin(); theVar != theClass->vars.end(); ++theVar ) {
///       std::cerr << "adding " << *theVar;
				if ( theClass->type ) {
///         std::cerr << " bound to ";
///         theClass->type->print( std::cerr );
///         std::cerr << std::endl;
					sub.add( *theVar, theClass->type );
				} else if ( theVar != theClass->vars.begin() ) {
					TypeInstType *newTypeInst = new TypeInstType( Type::Qualifiers(), *theClass->vars.begin(), theClass->data.kind == TypeDecl::Ftype );
///         std::cerr << " bound to variable " << *theClass->vars.begin() << std::endl;
					sub.add( *theVar, newTypeInst );
				} // if
			} // for
		} // for
///   std::cerr << "input env is:" << std::endl;
///   print( std::cerr, 8 );
///   std::cerr << "sub is:" << std::endl;
///   sub.print( std::cerr, 8 );
		sub.normalize();
	}

	void TypeEnvironment::print( std::ostream &os, Indenter indent ) const {
		for ( std::list< EqvClass >::const_iterator i = env.begin(); i != env.end(); ++i ) {
			i->print( os, indent );
		} // for
	}

	std::list< EqvClass >::iterator TypeEnvironment::internal_lookup( const std::string &var ) {
		for ( std::list< EqvClass >::iterator i = env.begin(); i != env.end(); ++i ) {
			if ( i->vars.find( var ) == i->vars.end() ) {
				return i;
			} // if
		} // for
		return env.end();
	}

	void TypeEnvironment::simpleCombine( const TypeEnvironment &second ) {
		env.insert( env.end(), second.env.begin(), second.env.end() );
	}

	void TypeEnvironment::extractOpenVars( OpenVarSet &openVars ) const {
		for ( std::list< EqvClass >::const_iterator eqvClass = env.begin(); eqvClass != env.end(); ++eqvClass ) {
			for ( std::set< std::string >::const_iterator var = eqvClass->vars.begin(); var != eqvClass->vars.end(); ++var ) {
				openVars[ *var ] = eqvClass->data;
			} // for
		} // for
	}

	void TypeEnvironment::addActual( const TypeEnvironment& actualEnv, OpenVarSet& openVars ) {
		for ( const EqvClass& c : actualEnv ) {
			EqvClass c2 = c;
			c2.allowWidening = false;
			for ( const std::string& var : c2.vars ) {
				openVars[ var ] = c2.data;
			}
			env.push_back( std::move(c2) );
		}
	}

	std::ostream & operator<<( std::ostream & out, const TypeEnvironment & env ) {
		env.print( out );
		return out;
	}
#endif

	PassVisitor<GcTracer> & operator<<( PassVisitor<GcTracer> & gc, const TypeEnvironment & env ) {
		for ( ClassRef c : env ) {
			maybeAccept( c.get_bound().type, gc );
		}
		return gc;
	}
} // namespace ResolvExpr

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
