/*
 * This file is part of the Cforall project
 *
 * $Id: TypeSubstitution.cc,v 1.9 2005/08/29 20:59:26 rcbilson Exp $
 *
 */

#include "Type.h"
#include "TypeSubstitution.h"


// Construct an empty substitution: no type bindings, no expression bindings.
// The pass-state scalars are given defined values so that reading them before
// normalize()/apply-style entry points set them is not undefined behaviour.
TypeSubstitution::TypeSubstitution()
    : subCount( 0 ), freeOnly( false )
{
}

// Deep-copy constructor: this substitution receives its own clones of every
// binding held by other (via initialize/add). The pass-state scalars are
// initialized explicitly rather than left indeterminate.
TypeSubstitution::TypeSubstitution( const TypeSubstitution &other )
    : subCount( 0 ), freeOnly( false )
{
    initialize( other, *this );
}

// Release every bound value. The maps hold owned pointers (add() stores
// clones), so both the type and the expression bindings are deleted here.
TypeSubstitution::~TypeSubstitution()
{
    TypeEnvType::iterator typeEntry = typeEnv.begin();
    while( typeEntry != typeEnv.end() ) {
	delete typeEntry->second;
	++typeEntry;
    }

    VarEnvType::iterator varEntry = varEnv.begin();
    while( varEntry != varEnv.end() ) {
	delete varEntry->second;
	++varEntry;
    }
}

// Deep-copy assignment. Self-assignment must do nothing: initialize()
// empties the destination before copying from the source, which would
// otherwise wipe the object being copied.
TypeSubstitution &
TypeSubstitution::operator=( const TypeSubstitution &other )
{
    if( this != &other ) {
	initialize( other, *this );
    }
    return *this;
}

void 
TypeSubstitution::initialize( const TypeSubstitution &src, TypeSubstitution &dest )
{
    dest.typeEnv.clear();
    dest.varEnv.clear();
    dest.add( src );
}

void 
TypeSubstitution::add( const TypeSubstitution &other )
{
    for( TypeEnvType::const_iterator i = other.typeEnv.begin(); i != other.typeEnv.end(); ++i ) {
	typeEnv[ i->first ] = i->second->clone();
    }
    for( VarEnvType::const_iterator i = other.varEnv.begin(); i != other.varEnv.end(); ++i ) {
	varEnv[ i->first ] = i->second->clone();
    }
}

void 
TypeSubstitution::add( std::string formalType, Type *actualType )
{
    TypeEnvType::iterator i = typeEnv.find( formalType );
    if( i != typeEnv.end() ) {
	delete i->second;
    }
    typeEnv[ formalType ] = actualType->clone();
}

// Drop the binding for formalType, if any, and delete the owned type.
// Unbound names are ignored.
void 
TypeSubstitution::remove( std::string formalType )
{
    TypeEnvType::iterator i = typeEnv.find( formalType );
    if( i != typeEnv.end() ) {
	delete i->second;
	// Erase through the iterator already found rather than searching again by key.
	typeEnv.erase( i );
    }
}

// Return the type bound to formalType, or a null pointer when it is unbound.
// The returned pointer is the substitution's own stored value, not a copy;
// the caller must not delete it.
Type *
TypeSubstitution::lookup( std::string formalType ) const
{
    TypeEnvType::const_iterator found = typeEnv.find( formalType );
    return found != typeEnv.end() ? found->second : 0;
}

// True when the substitution binds neither types nor expressions.
bool 
TypeSubstitution::empty() const
{
    if( !typeEnv.empty() ) {
	return false;
    }
    return varEnv.empty();
}

void
TypeSubstitution::normalize()
{
    do {
	subCount = 0;
	freeOnly = true;
	for( TypeEnvType::iterator i = typeEnv.begin(); i != typeEnv.end(); ++i ) {
	    i->second = i->second->acceptMutator( *this );
	}
    } while( subCount );
}

// Substitute a type-variable use. Names in boundVars (introduced by an
// enclosing forall) and unbound names are returned untouched. Otherwise the
// instance is deleted and replaced by a fresh clone of its binding, and the
// qualifiers written on the instance are added to the clone.
Type* 
TypeSubstitution::mutate(TypeInstType *inst)
{
    if( boundVars.find( inst->get_name() ) != boundVars.end() ) {
	return inst;
    }

    TypeEnvType::const_iterator binding = typeEnv.find( inst->get_name() );
    if( binding == typeEnv.end() ) {
	return inst;
    }

    ++subCount;
    Type *replacement = binding->second->clone();
    replacement->get_qualifiers() += inst->get_qualifiers();
    delete inst;
    return replacement;
}

// Substitute a name expression. When the name has an expression binding,
// the NameExpr is deleted and a fresh clone of the bound expression is
// returned. Otherwise the NameExpr is returned unchanged.
Expression* 
TypeSubstitution::mutate(NameExpr *nameExpr)
{
    VarEnvType::const_iterator binding = varEnv.find( nameExpr->get_name() );
    if( binding != varEnv.end() ) {
	subCount++;
	delete nameExpr;
	return binding->second->clone();
    }
    return nameExpr;
}

// Shared traversal for every type node that can carry a forall clause.
// When only free variables are being substituted (freeOnly), the names this
// type's forall introduces are added to boundVars for the duration of its
// subtree, so mutate(TypeInstType*) leaves them alone. The previous bound
// set is restored afterwards; the default Mutator traversal does the actual
// recursion.
template< typename TypeClass >
Type *
TypeSubstitution::handleType( TypeClass *type )
{
    BoundVarsType savedBoundVars( boundVars );
    if( freeOnly ) {
	const std::list< TypeDecl* > &forall = type->get_forall();
	for( std::list< TypeDecl* >::const_iterator decl = forall.begin(); decl != forall.end(); ++decl ) {
	    boundVars.insert( (*decl)->get_name() );
	}
    }
    Type *result = Mutator::mutate( type );
    boundVars = savedBoundVars;
    return result;
}

// Void type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(VoidType *voidType)
{
    return handleType< VoidType >( voidType );
}

// Basic type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(BasicType *basic)
{
    return handleType< BasicType >( basic );
}

// Pointer type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(PointerType *pointer)
{
    return handleType< PointerType >( pointer );
}

// Array type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(ArrayType *array)
{
    return handleType< ArrayType >( array );
}

// Function type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(FunctionType *function)
{
    return handleType< FunctionType >( function );
}

// Struct instance: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(StructInstType *structInst)
{
    return handleType< StructInstType >( structInst );
}

// Union instance: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(UnionInstType *unionInst)
{
    return handleType< UnionInstType >( unionInst );
}

// Enum instance: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(EnumInstType *enumInst)
{
    return handleType< EnumInstType >( enumInst );
}

// Context instance: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(ContextInstType *contextInst)
{
    return handleType< ContextInstType >( contextInst );
}

// Tuple type: traverse through handleType so that forall-bound names are scoped.
Type*
TypeSubstitution::mutate(TupleType *tuple)
{
    return handleType< TupleType >( tuple );
}

void 
TypeSubstitution::print( std::ostream &os, int indent ) const
{
    os << std::string( indent, ' ' ) << "Types:" << std::endl;
    for( TypeEnvType::const_iterator i = typeEnv.begin(); i != typeEnv.end(); ++i ) {
	os << std::string( indent+2, ' ' ) << i->first << " -> ";
	i->second->print( os, indent+4 );
	os << std::endl;
    }
    os << std::string( indent, ' ' ) << "Non-types:" << std::endl;
    for( VarEnvType::const_iterator i = varEnv.begin(); i != varEnv.end(); ++i ) {
	os << std::string( indent+2, ' ' ) << i->first << " -> ";
	i->second->print( os, indent+4 );
	os << std::endl;
    }
}

