//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// InstantiateGeneric.cc --
//
// Author           : Aaron B. Moss
// Created On       : Wed Nov 11 14:55:01 2015
// Last Modified By : Aaron B. Moss
// Last Modified On : Wed Nov 11 14:55:01 2015
// Update Count     : 1
//

#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "InstantiateGeneric.h"
#include "DeclMutator.h"

#include "ResolvExpr/typeops.h"
#include "SymTab/Indexer.h"
#include "SynTree/Declaration.h"
#include "SynTree/Mutator.h"
#include "SynTree/Statement.h"
#include "SynTree/Type.h"
#include "SynTree/TypeSubstitution.h"

#include "UniqueName.h"
#include "utility.h"

namespace GenPoly {

	/// Key for a unique concrete type; generic base type paired with type parameter list
	struct ConcreteType {
		ConcreteType() : base(NULL), params() {}

		ConcreteType(AggregateDecl *_base, const std::list< Type* >& _params) : base(_base), params() { cloneAll(_params, params); }

		ConcreteType(const ConcreteType& that) : base(that.base), params() { cloneAll(that.params, params); }

		/// Extracts types from a list of TypeExpr*
		ConcreteType(AggregateDecl *_base, const std::list< TypeExpr* >& _params) : base(_base), params() {
			for ( std::list< TypeExpr* >::const_iterator param = _params.begin(); param != _params.end(); ++param ) {
				params.push_back( (*param)->get_type()->clone() );
			}
		}

		ConcreteType& operator= (const ConcreteType& that) {
			deleteAll( params );
			params.clear();

			base = that.base;
			cloneAll( that.params, params );

			return *this;
		}

		~ConcreteType() { deleteAll( params ); }

		bool operator== (const ConcreteType& that) const {
			if ( base != that.base ) return false;

			SymTab::Indexer dummy;
			if ( params.size() != that.params.size() ) return false;
			for ( std::list< Type* >::const_iterator it = params.begin(), jt = that.params.begin(); it != params.end(); ++it, ++jt ) {
				if ( ! ResolvExpr::typesCompatible( *it, *jt, dummy ) ) return false;
			}
			return true;
		}

		AggregateDecl *base;        ///< Base generic type
		std::list< Type* > params;  ///< Instantiation parameters
	};
	
	/// Scoped table from (generic type, parameter list) keys to the declaration instantiated for that key.
	/// Scopes are kept as a stack; lookups search from the innermost scope outwards, insertions go into the innermost scope.
	class InstantiationMap {
		/// One recorded instantiation: the key it was made for and the declaration that resulted
		struct Instantiation {
			ConcreteType key;     ///< Base type and parameters identifying this instantiation
			AggregateDecl *decl;  ///< Declaration produced for that key

			Instantiation() : key(), decl(0) {}
			Instantiation(const ConcreteType &_key, AggregateDecl *_decl) : key(_key), decl(_decl) {}
		};
		/// All instantiations recorded for one generic type within one scope
		typedef std::vector< Instantiation > InstantiationList;
		/// Per-scope map from generic type to its recorded instantiations
		typedef std::map< AggregateDecl*, InstantiationList > Scope;

		std::vector< Scope > scopes;  ///< scope stack; front is outermost, back is innermost

	public:
		/// A fresh map starts out with a single (outermost) scope already open
		InstantiationMap() { beginScope(); }

		/// Pushes a new, empty innermost scope
		void beginScope() {
			scopes.push_back( Scope() );
		}

		/// Pops the innermost scope, discarding the instantiations recorded in it
		void endScope() {
			scopes.pop_back();
		}

	private:
		/// Finds the declaration previously recorded for generic instantiated with params.
		/// Every open scope is searched, innermost first; returns NULL if no scope holds a matching key.
		AggregateDecl* lookup( AggregateDecl *generic, const std::list< TypeExpr* >& params ) {
			ConcreteType wanted( generic, params );

			typedef std::vector< Scope >::const_reverse_iterator ScopeIter;
			for ( ScopeIter level = scopes.rbegin(); level != scopes.rend(); ++level ) {
				Scope::const_iterator entry = level->find( generic );
				if ( entry != level->end() ) {
					// this scope has instantiations of the generic type; compare each key against the wanted one
					const InstantiationList &candidates = entry->second;
					for ( InstantiationList::size_type i = 0; i < candidates.size(); ++i ) {
						if ( candidates[i].key == wanted ) return candidates[i].decl;
					}
				}
			}

			return NULL;
		}
	public:
		/// Struct flavour of lookup: keyed on the instance type's base struct
		StructDecl* lookup( StructInstType *inst, const std::list< TypeExpr* > &typeSubs ) {
			return static_cast< StructDecl* >( lookup( inst->get_baseStruct(), typeSubs ) );
		}
		/// Union flavour of lookup: keyed on the instance type's base union
		UnionDecl* lookup( UnionInstType *inst, const std::list< TypeExpr* > &typeSubs ) {
			return static_cast< UnionDecl* >( lookup( inst->get_baseUnion(), typeSubs ) );
		}

	private:
		/// Records decl as the instantiation of generic with params, in the innermost scope
		void insert( AggregateDecl *generic, const std::list< TypeExpr* > &params, AggregateDecl *decl ) {
			Scope &innermost = scopes.back();
			innermost[generic].push_back( Instantiation( ConcreteType( generic, params ), decl ) );
		}
	public:
		/// Struct flavour of insert: keyed on the instance type's base struct
		void insert( StructInstType *inst, const std::list< TypeExpr* > &typeSubs, StructDecl *decl ) {
			insert( inst->get_baseStruct(), typeSubs, decl );
		}
		/// Union flavour of insert: keyed on the instance type's base union
		void insert( UnionInstType *inst, const std::list< TypeExpr* > &typeSubs, UnionDecl *decl ) {
			insert( inst->get_baseUnion(), typeSubs, decl );
		}
	};

	/// DeclMutator pass that rewrites concrete uses of generic struct/union types into references to
	/// dedicated, freshly named declarations, remembering per scope which instantiations already exist
	class Instantiate : public DeclMutator {
		/// Instantiations made so far, keyed by (generic type, parameter list) and tracked per scope
		InstantiationMap instantiations;
		/// Source of names for the generated concrete declarations (constructed with the "_conc_" prefix string)
		UniqueName typeNamer;

	public:
		/// Base class and instantiation map are default-constructed; only the namer needs an argument
		Instantiate() : typeNamer("_conc_") {}

		/// Replaces a parameterized struct instance type with an instance of its concrete declaration where possible
		virtual Type* mutate( StructInstType *inst );
		/// Replaces a parameterized union instance type with an instance of its concrete declaration where possible
		virtual Type* mutate( UnionInstType *inst );

// 		virtual Expression* mutate( MemberExpr *memberExpr );
		
		/// Scope hooks: keep the instantiation map's scope stack in step with the mutator's
		virtual void doBeginScope();
		virtual void doEndScope();
	};
	
	/// Entry point of the pass: builds an Instantiate mutator and hands it the translation unit's declaration list
	void instantiateGeneric( std::list< Declaration* >& translationUnit ) {
		Instantiate pass;
		pass.mutateDeclarationList( translationUnit );
	}

	/// Pairs the arguments in params positionally with baseParams and appends a clone of each argument to out.
	///
	/// Returns true only when every entry of baseParams was given an argument and none of the arguments is
	/// itself a type variable (TypeInstType). The clones appended to out are left there in either case, so the
	/// caller is responsible for them whatever the result. Arguments beyond the length of baseParams are ignored.
	///
	/// NOTE: an earlier design (currently disabled) dispatched on the TypeDecl kind, substituting void for dtype
	/// parameters and a parameterless function type for ftype parameters, including for missing trailing
	/// arguments; at present every parameter is treated the same way and a missing argument is a failure.
	bool makeSubstitutions( const std::list< TypeDecl* >& baseParams, const std::list< Expression* >& params, std::list< TypeExpr* >& out ) {
		// keep filling out even after a type variable is seen, so the substitution list is complete
		bool sawTypeVariable = false;

		std::list< TypeDecl* >::const_iterator formal = baseParams.begin();
		std::list< Expression* >::const_iterator actual = params.begin();
		while ( formal != baseParams.end() && actual != params.end() ) {
			TypeExpr *typeArg = dynamic_cast< TypeExpr* >( *actual );
			assert(typeArg && "Aggregate parameters should be type expressions");
			out.push_back( typeArg->clone() );

			// an argument that is a bare type variable means this use is not (yet) concrete
			if ( dynamic_cast< TypeInstType* >( typeArg->get_type() ) ) {
				sawTypeVariable = true;
			}

			++formal;
			++actual;
		}

		// ran out of arguments before running out of formal parameters
		if ( formal != baseParams.end() ) return false;

		return ! sawTypeVariable;
	}

	/// Clones each declaration of in, applies the substitution baseParams => typeSubs to the clone, and appends it to out.
	/// The declarations in in are left untouched; out receives the clones in the same order.
	void substituteMembers( const std::list< Declaration* >& in, const std::list< TypeDecl* >& baseParams, const std::list< TypeExpr* >& typeSubs, 
						    std::list< Declaration* >& out ) {
		// one substitution object, built once, is reused for every member
		TypeSubstitution subs( baseParams.begin(), baseParams.end(), typeSubs.begin() );

		std::list< Declaration* >::const_iterator cursor = in.begin();
		while ( cursor != in.end() ) {
			Declaration *copy = (*cursor)->clone();
			subs.apply(copy);
			out.push_back( copy );
			++cursor;
		}
	}

	/// Rewrites a parameterized struct instance type into an instance of a concrete struct declaration.
	///
	/// The type is returned as-is (after the base mutation) when it has no parameters or when makeSubstitutions
	/// reports that it cannot be concretely instantiated. Otherwise an existing instantiation is reused, or a new
	/// StructDecl is built, registered via addDeclaration and recorded in the instantiation map; the original
	/// instance type is deleted and a new StructInstType naming the concrete declaration is returned.
	Type* Instantiate::mutate( StructInstType *inst ) {
		// let the base mutator process subtypes first; carry on only if the result is still a struct instance
		Type *visited = Mutator::mutate( inst );
		StructInstType *structInst = dynamic_cast< StructInstType* >( visited );
		if ( ! structInst ) return visited;

		// unparameterized uses need no instantiation
		if ( structInst->get_parameters().empty() ) return structInst;
		assert( structInst->get_baseParameters() && "Base struct has parameters" );

		// gather the type arguments; bail out (releasing the gathered clones) if the use is not concrete
		std::list< TypeExpr* > typeSubs;
		const bool concrete = makeSubstitutions( *structInst->get_baseParameters(), structInst->get_parameters(), typeSubs );
		if ( ! concrete ) {
			deleteAll( typeSubs );
			return structInst;
		}

		// reuse a matching instantiation if one is visible, otherwise create and record one
		StructDecl *concDecl = instantiations.lookup( structInst, typeSubs );
		if ( concDecl == NULL ) {
			concDecl = new StructDecl( typeNamer.newName( structInst->get_name() ) );
			substituteMembers( structInst->get_baseStruct()->get_members(), *structInst->get_baseParameters(), typeSubs, concDecl->get_members() );
			DeclMutator::addDeclaration( concDecl );
			instantiations.insert( structInst, typeSubs, concDecl );
		}

		// replacement instance type keeps the qualifiers of the original and points at the concrete declaration
		StructInstType *replacement = new StructInstType( structInst->get_qualifiers(), concDecl->get_name() );
		replacement->set_baseStruct( concDecl );

		deleteAll( typeSubs );
		delete structInst;
		return replacement;
	}
	
	/// Rewrites a parameterized union instance type into an instance of a concrete union declaration.
	///
	/// The type is returned as-is (after the base mutation) when it has no parameters or when makeSubstitutions
	/// reports that it cannot be concretely instantiated. Otherwise an existing instantiation is reused, or a new
	/// UnionDecl is built, registered via addDeclaration and recorded in the instantiation map; the original
	/// instance type is deleted and a new UnionInstType naming the concrete declaration is returned.
	Type* Instantiate::mutate( UnionInstType *inst ) {
		// let the base mutator process subtypes first; carry on only if the result is still a union instance
		Type *visited = Mutator::mutate( inst );
		UnionInstType *unionInst = dynamic_cast< UnionInstType* >( visited );
		if ( ! unionInst ) return visited;

		// unparameterized uses need no instantiation
		if ( unionInst->get_parameters().empty() ) return unionInst;
		assert( unionInst->get_baseParameters() && "Base union has parameters" );

		// gather the type arguments; bail out (releasing the gathered clones) if the use is not concrete
		std::list< TypeExpr* > typeSubs;
		const bool concrete = makeSubstitutions( *unionInst->get_baseParameters(), unionInst->get_parameters(), typeSubs );
		if ( ! concrete ) {
			deleteAll( typeSubs );
			return unionInst;
		}

		// reuse a matching instantiation if one is visible, otherwise create and record one
		UnionDecl *concDecl = instantiations.lookup( unionInst, typeSubs );
		if ( concDecl == NULL ) {
			concDecl = new UnionDecl( typeNamer.newName( unionInst->get_name() ) );
			substituteMembers( unionInst->get_baseUnion()->get_members(), *unionInst->get_baseParameters(), typeSubs, concDecl->get_members() );
			DeclMutator::addDeclaration( concDecl );
			instantiations.insert( unionInst, typeSubs, concDecl );
		}

		// replacement instance type keeps the qualifiers of the original and points at the concrete declaration
		UnionInstType *replacement = new UnionInstType( unionInst->get_qualifiers(), concDecl->get_name() );
		replacement->set_baseUnion( concDecl );

		deleteAll( typeSubs );
		delete unionInst;
		return replacement;
	}

// 	/// Gets the base struct or union declaration for a member expression; NULL if not applicable
// 	AggregateDecl* getMemberBaseDecl( MemberExpr *memberExpr ) {
// 		// get variable for member aggregate
// 		VariableExpr *varExpr = dynamic_cast< VariableExpr* >( memberExpr->get_aggregate() );
// 		if ( ! varExpr ) return NULL;
// 
// 		// get object for variable
// 		ObjectDecl *objectDecl = dynamic_cast< ObjectDecl* >( varExpr->get_var() );
// 		if ( ! objectDecl ) return NULL;
// 
// 		// get base declaration from object type
// 		Type *objectType = objectDecl->get_type();
// 		StructInstType *structType = dynamic_cast< StructInstType* >( objectType );
// 		if ( structType ) return structType->get_baseStruct();
// 		UnionInstType *unionType = dynamic_cast< UnionInstType* >( objectType );
// 		if ( unionType ) return unionType->get_baseUnion();
// 
// 		return NULL;
// 	}
// 
// 	/// Finds the declaration with the given name, returning decls.end() if none such
// 	std::list< Declaration* >::const_iterator findDeclNamed( const std::list< Declaration* > &decls, const std::string &name ) {
// 		for( std::list< Declaration* >::const_iterator decl = decls.begin(); decl != decls.end(); ++decl ) {
// 			if ( (*decl)->get_name() == name ) return decl;
// 		}
// 		return decls.end();
// 	}
// 	
// 	Expression* Instantiate::mutate( MemberExpr *memberExpr ) {
// 		// mutate, exiting early if no longer MemberExpr
// 		Expression *expr = Mutator::mutate( memberExpr );
// 		memberExpr = dynamic_cast< MemberExpr* >( expr );
// 		if ( ! memberExpr ) return expr;
// 
// 		// get declaration of member and base declaration of member, exiting early if not found
// 		AggregateDecl *memberBase = getMemberBaseDecl( memberExpr );
// 		if ( ! memberBase ) return memberExpr;
// 		DeclarationWithType *memberDecl = memberExpr->get_member();
// 		std::list< Declaration* >::const_iterator baseIt = findDeclNamed( memberBase->get_members(), memberDecl->get_name() );
// 		if ( baseIt == memberBase->get_members().end() ) return memberExpr;
// 		DeclarationWithType *baseDecl = dynamic_cast< DeclarationWithType* >( *baseIt );
// 		if ( ! baseDecl ) return memberExpr;
// 
// 		// check if stated type of the member is not the type of the member's declaration; if so, need a cast
// 		// this *SHOULD* be safe, I don't think anything but the void-replacements I put in for dtypes would make it past the typechecker
// 		SymTab::Indexer dummy;
// 		if ( ResolvExpr::typesCompatible( memberDecl->get_type(), baseDecl->get_type(), dummy ) ) return memberExpr;
// 		else return new CastExpr( memberExpr, memberDecl->get_type() );
// 	}
	
	/// Scope-entry hook: runs the DeclMutator scope-entry handling, then opens a matching scope in the
	/// instantiation map so that instantiations recorded from here on belong to the new scope.
	void Instantiate::doBeginScope() {
		DeclMutator::doBeginScope();
		instantiations.beginScope();
	}

	/// Scope-exit hook: runs the DeclMutator scope-exit handling, then drops the innermost scope of the
	/// instantiation map, forgetting the instantiations that were recorded in it.
	void Instantiate::doEndScope() {
		DeclMutator::doEndScope();
		instantiations.endScope();
	}
	
}  // namespace GenPoly

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
