/*
 * This file is part of the Cforall project
 *
 * $Id: ConversionCost.cc,v 1.11 2005/08/29 20:14:15 rcbilson Exp $
 *
 */

#include "ConversionCost.h"
#include "typeops.h"
#include "SynTree/Type.h"
#include "SynTree/Visitor.h"
#include "SymTab/Indexer.h"


namespace ResolvExpr {

// Cost::zero is the cost of an exact (compatible) match.
// Cost::infinity, with -1 in every component, is the sentinel meaning
// "no conversion exists".
const Cost Cost::zero( 0, 0, 0 );
const Cost Cost::infinity( -1, -1, -1 );

// Compute the cost of converting a value of type src to type dest.
//
// A type-variable destination is first resolved, either through its binding
// in env or through the base type of its declaration in indexer (that adds
// one safe step).  Compatible types (ignoring qualifiers) cost nothing.
// Conversion to void costs one safe step.  Otherwise the ConversionCost
// visitor prices the conversion from src's shape.  Returns Cost::infinity
// when no conversion is known.
Cost
conversionCost( Type *src, Type *dest, const SymTab::Indexer &indexer, const TypeEnvironment &env )
{
  if( TypeInstType *destAsTypeInst = dynamic_cast< TypeInstType* >( dest ) ) {
    EqvClass eqvClass;
    NamedTypeDecl *namedType;
    if( env.lookup( destAsTypeInst->get_name(), eqvClass ) ) {
      // NOTE(review): an equivalence class can presumably exist without a
      // bound type (null eqvClass.type).  Recursing would then pass a null
      // destination on to typesCompatibleIgnoreQualifiers and the visitor,
      // so report that no conversion is known instead.
      if( ! eqvClass.type ) {
        return Cost::infinity;
      }
      return conversionCost( src, eqvClass.type, indexer, env );
    } else if( ( namedType = indexer.lookupType( destAsTypeInst->get_name() ) ) ) {
      TypeDecl *type = dynamic_cast< TypeDecl* >( namedType );
      // all typedefs should be gone by this point
      assert( type );
      if( type->get_base() ) {
        // Going through the declared base type costs one extra safe step.
        // Check for infinity explicitly so that adding the step can never
        // turn "no conversion" into a finite cost.
        Cost baseCost = conversionCost( src, type->get_base(), indexer, env );
        if( baseCost == Cost::infinity ) {
          return Cost::infinity;
        }
        return baseCost + Cost( 0, 0, 1 );
      }
    }
  }

  if( typesCompatibleIgnoreQualifiers( src, dest, indexer, env ) ) {
    return Cost::zero;
  } else if( dynamic_cast< VoidType* >( dest ) ) {
    return Cost( 0, 0, 1 );
  }

  // Let the visitor price the conversion based on the source type's shape;
  // it reports Cost::infinity when it finds no applicable conversion.
  ConversionCost converter( dest, indexer, env );
  src->accept( converter );
  return converter.get_cost();
}

// Build a visitor that prices conversions *to* dest.  The cost starts at
// Cost::infinity ("no conversion found") until a visit method assigns one.
ConversionCost::ConversionCost( Type *dest, const SymTab::Indexer &indexer, const TypeEnvironment &env )
  : dest( dest ),
    indexer( indexer ),
    cost( Cost::infinity ),
    env( env )
{}

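/*
Safe-conversion hierarchy among the basic types.  Each edge in the "New"
diagram is one safe conversion step.  For every (source, destination) pair,
costMatrix below records the length of the shortest upward path from source
to destination, or -1 when there is no such path (such conversions are
charged as unsafe).

Old (superseded)
================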
                     Double
                       |
                     Float
                       |
                     ULong
                     /   \
                  UInt    Long
                     \   /
                      Int
                       |
                     Ushort
                       |
                     Short
                       |
                     Uchar
                     /   \
                 Schar   Char

New
===
                               +-----LongDoubleComplex--+
                   LongDouble--+          |             +-LongDoubleImag
                       |         +---DoubleComplex---+         |
                     Double------+        |          +----DoubleImag
                       |           +-FloatComplex-+            |
                     Float---------+              +-------FloatImag
                       |
                   ULongLong
                       |
                   LongLong
                       |
                     ULong
                     /   \
                  UInt    Long
                     \   /
                      Int
                       |
                     Ushort
                       |
                     Short
                       |
                     Uchar
                     /   \
                 Schar   Char
                     \   /
                      Bool
*/

// Safe-conversion distances between basic types, indexed as
// costMatrix[ source kind ][ destination kind ] by BasicType kind values.
// Rows and columns follow the order in the header comment below.
// NOTE(review): that order must match the BasicType kind enumeration;
// confirm against SynTree/Type.h whenever a kind is added or reordered.
// A non-negative entry is the number of safe conversion steps along the
// "New" hierarchy above; ConversionCost::visit(BasicType*) turns it into
// Cost( 0, 0, entry ).  An entry of -1 means no safe conversion exists, and
// that visitor charges it as Cost( 1, 0, 0 ).
static const int costMatrix[ BasicType::NUMBER_OF_BASIC_TYPES ][ BasicType::NUMBER_OF_BASIC_TYPES ] =
{
/* Src \ Dest:	Bool	Char	SChar	UChar	Short	UShort	Int	UInt	Long	ULong	LLong	ULLong	Float	Double	LDbl	FCplex	DCplex	LDCplex	FImag	DImag	LDImag */
/* Bool */ 	{ 0,	1,	1,	2,	3,	4,	5,	6,	6,	7,	8,	9,	10,	11,	12,	11,	12,	13,	-1,	-1,	-1 },
/* Char */ 	{ -1,	0,	-1,	1,	2,	3,	4,	5,	5,	6,	7,	8,	9,	10,	11,	10,	11,	12,	-1,	-1,	-1 },
/* SChar */ 	{ -1,	-1,	0,	1,	2,	3,	4,	5,	5,	6,	7,	8,	9,	10,	11,	10,	11,	12,	-1,	-1,	-1 },
/* UChar */ 	{ -1,	-1,	-1,	0,	1,	2,	3,	4,	4,	5,	6,	7,	8,	9,	10,	9,	10,	11,	-1,	-1,	-1 },
/* Short */ 	{ -1,	-1,	-1,	-1,	0,	1,	2,	3,	3,	4,	5,	6,	7,	8,	9,	8,	9,	10,	-1,	-1,	-1 },
/* UShort */ 	{ -1,	-1,	-1,	-1,	-1,	0,	1,	2,	2,	3,	4,	5,	6,	7,	8,	7,	8,	9,	-1,	-1,	-1 },
/* Int */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	0,	1,	1,	2,	3,	4,	5,	6,	7,	6,	7,	8,	-1,	-1,	-1 },
/* UInt */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	-1,	1,	2,	3,	4,	5,	6,	5,	6,	7,	-1,	-1,	-1 },
/* Long */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	3,	4,	5,	6,	5,	6,	7,	-1,	-1,	-1 },
/* ULong */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	3,	4,	5,	4,	5,	6,	-1,	-1,	-1 },
/* LLong */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	3,	4,	3,	4,	5,	-1,	-1,	-1 },
/* ULLong */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	3,	2,	3,	4,	-1,	-1,	-1 },
/* Float */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	1,	2,	3,	-1,	-1,	-1 },
/* Double */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	-1,	1,	2,	-1,	-1,	-1 },
/* LDbl */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	-1,	-1,	1,	-1,	-1,	-1 },
/* FCplex */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	2,	-1,	-1,	-1 },
/* DCplex */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	1,	-1,	-1,	-1 },
/* LDCplex */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	0,	-1,	-1,	-1 },
/* FImag */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	1,	2,	3,	0,	1,	2 },
/* DImag */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	1,	2,	-1,	0,	1 },
/* LDImag */ 	{ -1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	-1,	1,	-1,	-1,	0 }
};

void 
ConversionCost::visit(VoidType *voidType)
{
  cost = Cost::infinity;
}

// Basic-to-basic conversions are priced from costMatrix.  A table entry of -1
// is charged as one unsafe conversion; any other entry is that many safe steps.
// Non-basic destinations leave cost untouched.
void
ConversionCost::visit( BasicType *basicType )
{
  BasicType *destAsBasic = dynamic_cast< BasicType* >( dest );
  if( destAsBasic == 0 ) {
    return;
  }
  int steps = costMatrix[ basicType->get_kind() ][ destAsBasic->get_kind() ];
  cost = ( steps == -1 ) ? Cost( 1, 0, 0 ) : Cost( 0, 0, steps );
}

// Pointer-to-pointer conversions.  If the destination's pointee is at least
// as qualified as the source's and the pointees are otherwise compatible, the
// conversion costs one safe step.  Otherwise ptrsAssignable decides: a
// negative result costs one safe step, a positive result one unsafe
// conversion, and zero leaves cost unchanged.
// NOTE(review): the sign convention of ptrsAssignable is defined elsewhere;
// confirm the safe/unsafe mapping against its implementation.
void
ConversionCost::visit( PointerType *pointerType )
{
  PointerType *destAsPtr = dynamic_cast< PointerType* >( dest );
  if( destAsPtr == 0 ) {
    return;
  }

  Type *srcBase = pointerType->get_base();
  Type *destBase = destAsPtr->get_base();
  bool onlyGainsQualifiers = srcBase->get_qualifiers() <= destBase->get_qualifiers()
    && typesCompatibleIgnoreQualifiers( srcBase, destBase, indexer, env );
  if( onlyGainsQualifiers ) {
    cost = Cost( 0, 0, 1 );
    return;
  }

  int assignResult = ptrsAssignable( srcBase, destBase, env );
  if( assignResult < 0 ) {
    cost = Cost( 0, 0, 1 );
  } else if( assignResult > 0 ) {
    cost = Cost( 1, 0, 0 );
  }
}

void 
ConversionCost::visit(ArrayType *arrayType)
{
}

void 
ConversionCost::visit(FunctionType *functionType)
{
}

// A struct converts, at zero cost, only to a struct type with the same name.
void
ConversionCost::visit( StructInstType *inst )
{
  StructInstType *destStruct = dynamic_cast< StructInstType* >( dest );
  if( destStruct != 0 && inst->get_name() == destStruct->get_name() ) {
    cost = Cost::zero;
  }
}

// A union converts, at zero cost, only to a union type with the same name.
// The destination must be tested as a UnionInstType, not a StructInstType:
// a struct and a union sharing a tag name are distinct types, and a union
// must match its own type.
void
ConversionCost::visit(UnionInstType *inst)
{
  if( UnionInstType *destAsInst = dynamic_cast< UnionInstType* >( dest ) ) {
    if( inst->get_name() == destAsInst->get_name() ) {
      cost = Cost::zero;
    }
  }
}

void 
ConversionCost::visit(EnumInstType *inst)
{
  static Type::Qualifiers q;
  static BasicType integer( q, BasicType::SignedInt );
  integer.accept( *this );
  if( cost < Cost( 1, 0, 0 ) ) {
    cost.incSafe();
  }
}

void 
ConversionCost::visit(ContextInstType *inst)
{
}

// Price a conversion from a type-variable / named-type source.
//  - If the name is bound in env, price the conversion from its binding.
//  - Otherwise, if dest is also a type instance, only the same name is free.
//  - Otherwise, go through the declared base type, at one extra safe step.
// Any other case leaves cost unchanged.
void
ConversionCost::visit(TypeInstType *inst)
{
  EqvClass eqvClass;
  NamedTypeDecl *namedType;
  if( env.lookup( inst->get_name(), eqvClass ) ) {
    // NOTE(review): an equivalence class can presumably exist without a bound
    // type (null eqvClass.type).  Recursing on it would hand a null source to
    // conversionCost, so leave cost unchanged in that case.
    if( eqvClass.type ) {
      cost = conversionCost( eqvClass.type, dest, indexer, env );
    }
  } else if( TypeInstType *destAsInst = dynamic_cast< TypeInstType* >( dest ) ) {
    if( inst->get_name() == destAsInst->get_name() ) {
      cost = Cost::zero;
    }
  } else if( ( namedType = indexer.lookupType( inst->get_name() ) ) ) {
    TypeDecl *type = dynamic_cast< TypeDecl* >( namedType );
    // all typedefs should be gone by this point
    assert( type );
    if( type->get_base() ) {
      // Check for infinity explicitly so that adding the extra safe step can
      // never turn "no conversion" into a finite cost.
      Cost baseCost = conversionCost( type->get_base(), dest, indexer, env );
      if( baseCost == Cost::infinity ) {
        cost = Cost::infinity;
      } else {
        cost = baseCost + Cost( 0, 0, 1 );
      }
    }
  }
}

// Tuple-to-tuple conversion: the sum of the element-wise conversion costs.
// Cost is left unchanged if any element has no conversion.  It is set to
// Cost::infinity if the tuples have different lengths.
void
ConversionCost::visit(TupleType *tupleType)
{
  TupleType *destAsTuple = dynamic_cast< TupleType* >( dest );
  if( ! destAsTuple ) {
    return;
  }

  std::list< Type* >::const_iterator srcIt = tupleType->get_types().begin();
  std::list< Type* >::const_iterator srcEnd = tupleType->get_types().end();
  std::list< Type* >::const_iterator destIt = destAsTuple->get_types().begin();
  std::list< Type* >::const_iterator destEnd = destAsTuple->get_types().end();

  Cost total = Cost::zero;
  // Stop at the end of EITHER list.  Advancing only on the source side would
  // dereference destIt past end() when the source tuple is longer.
  for( ; srcIt != srcEnd && destIt != destEnd; ++srcIt, ++destIt ) {
    Cost elementCost = conversionCost( *srcIt, *destIt, indexer, env );
    if( elementCost == Cost::infinity ) {
      return;
    }
    total += elementCost;
  }

  if( srcIt != srcEnd || destIt != destEnd ) {
    cost = Cost::infinity;   // arity mismatch in either direction
  } else {
    cost = total;
  }
}

} // namespace ResolvExpr
