//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// ConversionCost.cc --
//
// Author           : Richard C. Bilson
// Created On       : Sun May 17 07:06:19 2015
// Last Modified By : Peter A. Buhr
// Last Modified On : Mon Sep 25 15:43:34 2017
// Update Count     : 10
//

#include "ConversionCost.h"

#include <cassert>                       // for assert
#include <list>                          // for list, list<>::const_iterator
#include <string>                        // for operator==, string

#include "ResolvExpr/Cost.h"             // for Cost
#include "ResolvExpr/TypeEnvironment.h"  // for EqvClass, TypeEnvironment
#include "SymTab/Indexer.h"              // for Indexer
#include "SynTree/Declaration.h"         // for TypeDecl, NamedTypeDecl
#include "SynTree/Type.h"                // for Type, BasicType, TypeInstType
#include "typeops.h"                     // for typesCompatibleIgnoreQualifiers

namespace ResolvExpr {
	// Canonical cost values.  Each non-zero constant sets exactly one component, so
	// judging by these definitions the constructor argument order is
	// ( unsafe, poly, safe, reference ).  Cost::infinity (-1 everywhere) is what the
	// functions below return when no conversion is possible.
	const Cost Cost::zero      = Cost(  0,  0,  0,  0 );
	const Cost Cost::infinity  = Cost( -1, -1, -1, -1 );
	const Cost Cost::unsafe    = Cost(  1,  0,  0,  0 );
	const Cost Cost::poly      = Cost(  0,  1,  0,  0 );
	const Cost Cost::safe      = Cost(  0,  0,  1,  0 );
	const Cost Cost::reference = Cost(  0,  0,  0,  1 );

// Debug tracing switch: change the 0 below to 1 to compile the PRINT(...) bodies in.
#if 0
#define PRINT(x) x
#else
#define PRINT(x)
#endif
	// Computes the cost of converting a value of type src to type dest.
	// A type-variable destination is first replaced by what it stands for (its binding in
	// env, or the base of its declaration).  Then compatible types cost zero, void is a safe
	// target, reference targets are delegated to convertToReferenceCost, and everything else
	// is handled by double-dispatching through a ConversionCost visitor on src.
	Cost conversionCost( Type *src, Type *dest, const SymTab::Indexer &indexer, const TypeEnvironment &env ) {
		if ( TypeInstType *destVar = dynamic_cast< TypeInstType* >( dest ) ) {
			PRINT( std::cerr << "type inst " << destVar->get_name(); )
			EqvClass binding;
			if ( env.lookup( destVar->get_name(), binding ) ) {
				// an equivalence class without a bound type cannot be converted to
				if ( ! binding.type ) return Cost::infinity;
				return conversionCost( src, binding.type, indexer, env );
			}
			if ( NamedTypeDecl *declared = indexer.lookupType( destVar->get_name() ) ) {
				PRINT( std::cerr << " found" << std::endl; )
				TypeDecl *typeDecl = dynamic_cast< TypeDecl* >( declared );
				// all typedefs should be gone by this point
				assert( typeDecl );
				if ( typeDecl->get_base() ) {
					// going through the declared base costs one extra safe step
					return conversionCost( src, typeDecl->get_base(), indexer, env ) + Cost::safe;
				}
			}
			PRINT( std::cerr << " not found" << std::endl; )
		}
		PRINT(
			std::cerr << "src is ";
			src->print( std::cerr );
			std::cerr << std::endl << "dest is ";
			dest->print( std::cerr );
			std::cerr << std::endl << "env is" << std::endl;
			env.print( std::cerr, 8 );
		)
		if ( typesCompatibleIgnoreQualifiers( src, dest, indexer, env ) ) {
			PRINT( std::cerr << "compatible!" << std::endl; )
			return Cost::zero;
		}
		if ( dynamic_cast< VoidType* >( dest ) ) {
			return Cost::safe;
		}
		if ( ReferenceType * destRef = dynamic_cast< ReferenceType * >( dest ) ) {
			PRINT( std::cerr << "conversionCost: dest is reference" << std::endl; )
			return convertToReferenceCost( src, destRef, indexer, env, [](Type * from, Type * to, const TypeEnvironment & tenv, const SymTab::Indexer &){
				return ptrsAssignable( from, to, tenv );
			});
		}

		// General case: let src's dynamic type pick the matching visit() overload.
		ConversionCost converter( dest, indexer, env );
		src->accept( converter );
		if ( converter.get_cost() == Cost::infinity ) {
			return Cost::infinity;
		}
		return converter.get_cost() + Cost::zero;
	}

	// Cost of converting src to dest when at least one side is (or may be) a reference.
	// diff is the reference depth of src minus that of dest (see the overload below).
	// Surplus reference levels are stripped one at a time, each charging one reference
	// step; diff == 0 and diff == -1 are the terminal cases.  func decides pointer-like
	// assignability between the referenced types when they are not simply compatible.
	Cost convertToReferenceCost( Type * src, Type * dest, int diff, const SymTab::Indexer & indexer, const TypeEnvironment & env, PtrsFunction func ) {
		PRINT( std::cerr << "convert to reference cost... diff " << diff << std::endl; )

		if ( diff > 0 ) {
			// src is deeper: drop one reference level from src
			Type * srcBase = strict_dynamic_cast< ReferenceType * >( src )->get_base();
			Cost stripped = convertToReferenceCost( srcBase, dest, diff - 1, indexer, env, func );
			stripped.incReference();
			return stripped;
		}

		if ( diff < -1 ) {
			// dest is at least two levels deeper: drop one reference level from dest
			Type * destBase = strict_dynamic_cast< ReferenceType * >( dest )->get_base();
			Cost stripped = convertToReferenceCost( src, destBase, diff + 1, indexer, env, func );
			stripped.incReference();
			return stripped;
		}

		if ( diff == 0 ) {
			ReferenceType * srcRef = dynamic_cast< ReferenceType * >( src );
			ReferenceType * destRef = dynamic_cast< ReferenceType * >( dest );
			if ( srcRef == nullptr || destRef == nullptr ) {
				PRINT( std::cerr << "reference to rvalue conversion" << std::endl; )
				ConversionCost converter( dest, indexer, env );
				src->accept( converter );
				return converter.get_cost();
			}
			// pointer-like conversions between references
			PRINT( std::cerr << "converting between references" << std::endl; )
			Type * srcBase = srcRef->get_base();
			Type * destBase = destRef->get_base();
			if ( srcBase->get_qualifiers() <= destBase->get_qualifiers() && typesCompatibleIgnoreQualifiers( srcBase, destBase, indexer, env ) ) {
				return Cost::safe;
			}
			// xxx - this discards reference qualifiers from consideration -- reducing qualifiers is a safe conversion; is this right?
			int assignResult = func( srcBase, destBase, env, indexer );
			PRINT( std::cerr << "comparing references: " << assignResult << " " << srcRef << " " << destRef << std::endl; )
			if ( assignResult > 0 ) return Cost::safe;
			if ( assignResult < 0 ) return Cost::unsafe;
			return Cost::infinity;
		}

		// Only diff == -1 remains: dest is exactly one reference level deeper than src.
		ReferenceType * destRef = dynamic_cast< ReferenceType * >( dest );
		assert( diff == -1 && destRef );
		PRINT( std::cerr << "dest is: " << dest << " / src is: " << src << std::endl; )
		Type * destBase = destRef->get_base();
		if ( ! typesCompatibleIgnoreQualifiers( src, destBase, indexer, env ) ) {
			PRINT( std::cerr << "attempting to convert from incompatible base type -- fail" << std::endl; )
			return Cost::infinity;
		}
		PRINT( std::cerr << "converting compatible base type" << std::endl; )
		if ( src->get_lvalue() ) {
			PRINT(
				std::cerr << "lvalue to reference conversion" << std::endl;
				std::cerr << src << " => " << destRef << std::endl;
			)
			// lvalue-to-reference conversion:  cv lvalue T => cv T &
			if ( src->get_qualifiers() == destBase->get_qualifiers() ) {
				return Cost::reference; // cost needs to be non-zero to add cast
			} else if ( src->get_qualifiers() < destBase->get_qualifiers() ) {
				return Cost::safe; // cost needs to be higher than previous cast to differentiate adding qualifiers vs. keeping same
			}
			return Cost::unsafe;
		}
		if ( destBase->get_const() ) {
			PRINT( std::cerr << "rvalue to const ref conversion" << std::endl; )
			// rvalue-to-const-reference conversion: T => const T &
			return Cost::safe;
		}
		PRINT( std::cerr << "rvalue to non-const reference conversion" << std::endl; )
		// rvalue-to-reference conversion: T => T &
		return Cost::unsafe;
	}

	// Entry point for reference targets: measures how many more reference levels src has
	// than dest and hands that difference to the recursive overload above.
	Cost convertToReferenceCost( Type * src, ReferenceType * dest, const SymTab::Indexer & indexer, const TypeEnvironment & env, PtrsFunction func ) {
		const int srcDepth = src->referenceDepth();
		const int destDepth = dest->referenceDepth();
		return convertToReferenceCost( src, dest, srcDepth - destDepth, indexer, env, func );
	}

	// Builds a visitor that prices conversions into dest; until a visit() method finds a
	// conversion, the recorded cost is Cost::infinity.
	ConversionCost::ConversionCost( Type *dest, const SymTab::Indexer &indexer, const TypeEnvironment &env )
		: dest( dest ),
		  indexer( indexer ),
		  cost( Cost::infinity ),
		  env( env )
	{}

/*
            Old
            ===
           Double
             |
           Float
             |
           ULong
           /   \
        UInt    Long
           \   /
            Int
             |
           Ushort
             |
           Short
             |
           Uchar
           /   \
        Schar   Char

                                New
                                ===
                       +-----LongDoubleComplex--+
           LongDouble--+          |             +-LongDoubleImag
             |         +---DoubleComplex---+         |
           Double------+        |          +----DoubleImag
             |           +-FloatComplex-+            |
           Float---------+              +-------FloatImag
             |
          UInt128
             |
          Int128
             |
          ULongLong
             |
          LongLong
             |
           ULong
           /   \
        UInt    Long
           \   /
            Int
             |
           Ushort
             |
           Short
             |
           Uchar
           /   \
        Schar   Char
           \   /
            Bool
*/

	// costMatrix[ src ][ dest ] is the number of safe conversion steps from basic type src to
	// basic type dest (0 on the diagonal).  -1 means there is no safe conversion; visit(BasicType*)
	// charges such pairs Cost::unsafe (not Cost::infinity).  The table is indexed by
	// BasicType::get_kind(), so its row/column order must follow BasicType's kind enumeration,
	// in the order named in the header comment below.
	static const int costMatrix[ BasicType::NUMBER_OF_BASIC_TYPES ][ BasicType::NUMBER_OF_BASIC_TYPES ] = {
	/* Src \ Dest:	Bool	Char	SChar	UChar	Short	UShort	Int 	UInt	Long	ULong	LLong	ULLong	Float	Double	LDbl	FCplex	DCplex	LDCplex	FImag	DImag	LDImag	I128	U128 */
		/* Bool */ 	{ 0,	1,		1,		2,		3,		4,		5,		6,		6,		7,		8,		9,		12,		13,		14,		12,		13,		14,		-1,		-1,		-1,		10,		11,	},
		/* Char */ 	{ -1,	0,		-1,		1,		2,		3,		4,		5,		5,		6,		7,		8,		11,		12,		13,		11,		12,		13,		-1,		-1,		-1,		9,		10,	},
		/* SChar */ { -1,	-1,		0,		1,		2,		3,		4,		5,		5,		6,		7,		8,		11,		12,		13,		11,		12,		13,		-1,		-1,		-1,		9,		10,	},
		/* UChar */ { -1,	-1,		-1,		0,		1,		2,		3,		4,		4,		5,		6,		7,		10,		11,		12,		10,		11,		12,		-1,		-1,		-1,		8,		9,	},
		/* Short */ { -1,	-1,		-1,		-1,		0,		1,		2,		3,		3,		4,		5,		6,		9,		10,		11,		9,		10,		11,		-1,		-1,		-1,		7,		8,	},
		/* UShort */{ -1,	-1,		-1,		-1,		-1,		0,		1,		2,		2,		3,		4,		5,		8,		9,		10,		8,		9,		10,		-1,		-1,		-1,		6,		7,	},
		/* Int */ 	{ -1,	-1,		-1,		-1,		-1,		-1,		0,		1,		1,		2,		3,		4,		7,		8,		9,		7,		8,		9,		-1,		-1,		-1,		5,		6,	},
		/* UInt */ 	{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		0,		-1,		1,		2,		3,		6,		7,		8,		6,		7,		8,		-1,		-1,		-1,		4,		5,	},
		/* Long */ 	{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		2,		3,		6,		7,		8,		6,		7,		8,		-1,		-1,		-1,		4,		5,	},
		/* ULong */ { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		2,		5,		6,		7,		5,		6,		7,		-1,		-1,		-1,		3,		4,	},
		/* LLong */ { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		4,		5,		6,		4,		5,		6,		-1,		-1,		-1,		2,		3,	},
		/* ULLong */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		3,		4,		5,		3,		4,		5,		-1,		-1,		-1,		1,		2,	},

		/* Float */ { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		2,		1,		2,		3,		-1,		-1,		-1,		-1,		-1,	},
		/* Double */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		-1,		1,		2,		-1,		-1,		-1,		-1,		-1,	},
		/* LDbl */ 	{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		-1,		-1,		1,		-1,		-1,		-1,		-1,		-1,	},
		/* FCplex */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		2,		-1,		-1,		-1,		-1,		-1,	},
		/* DCplex */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		1,		-1,		-1,		-1,		-1,		-1,	},
		/* LDCplex */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		0,		-1,		-1,		-1,		-1,		-1,	},
		/* FImag */ { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		1,		2,		3,		0,		1,		2,		-1,		-1,	},
		/* DImag */ { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		1,		2,		-1,		0,		1,		-1,		-1,	},
		/* LDImag */{ -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		1,		-1,		-1,		0,		-1,		-1,	},

		/* I128 */  { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		2,		3,		4,		3,		4,		5,		-1,		-1,		-1,		0,		1,	},
		/* U128 */  { -1,	-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		-1,		1,		2,		3,		2,		3,		4,		-1,		-1,		-1,		-1,		0,	},
	};

	void ConversionCost::visit( __attribute((unused)) VoidType *voidType ) {
		cost = Cost::infinity;
	}

	// Basic-to-basic conversions are priced from costMatrix; a basic type may also reach an
	// enum, zero_t or one_t, but only unsafely.  Any other destination is left untouched.
	void ConversionCost::visit(BasicType *basicType) {
		BasicType *destAsBasic = dynamic_cast< BasicType* >( dest );
		if ( destAsBasic == nullptr ) {
			// xxx - not positive this is correct, but appears to allow casting int => enum
			if ( dynamic_cast< EnumInstType *>( dest ) || dynamic_cast< ZeroType* >( dest ) || dynamic_cast< OneType* >( dest ) ) {
				cost = Cost::unsafe;
			}
			return;
		}

		const int safeSteps = costMatrix[ basicType->get_kind() ][ destAsBasic->get_kind() ];
		if ( safeSteps == -1 ) {
			// no safe path in the table
			cost = Cost::unsafe;
			return;
		}
		cost = Cost::zero;
		cost.incSafe( safeSteps );
	}

	// Pointer-to-pointer conversion cost.
	//  - compatible pointees, same pointee qualifiers            => zero
	//  - compatible pointees, dest pointee gains qualifiers      => safe
	//  - otherwise ptrsAssignable decides: > 0 and no pointee qualifier dropped => safe,
	//    < 0 => unsafe, == 0 => no conversion (cost left at infinity)
	// A pointer converts to zero_t only unsafely.
	void ConversionCost::visit( PointerType * pointerType ) {
		if ( PointerType *destAsPtr = dynamic_cast< PointerType* >( dest ) ) {
			PRINT( std::cerr << pointerType << " ===> " << destAsPtr; )
			// qualifiers of the pointed-to types, not of the pointers themselves
			Type::Qualifiers tq1 = pointerType->get_base()->get_qualifiers();
			Type::Qualifiers tq2 = destAsPtr->get_base()->get_qualifiers();
			if ( tq1 <= tq2 && typesCompatibleIgnoreQualifiers( pointerType->get_base(), destAsPtr->get_base(), indexer, env ) ) {
				if ( tq1 == tq2 ) {
					// types are the same
					cost = Cost::zero;
				} else {
					// types are the same, except otherPointer has more qualifiers
					PRINT( std::cerr << " :: compatible and good qualifiers" << std::endl; )
					cost = Cost::safe;
				}
			} else {  // xxx - this discards qualifiers from consideration -- reducing qualifiers is a safe conversion; is this right?
				int assignResult = ptrsAssignable( pointerType->base, destAsPtr->base, env );
				PRINT( std::cerr << " :: " << assignResult << std::endl; )
				// Compare pointee qualifiers (tq1 vs tq2).  Comparing against the destination
				// pointer's own qualifiers wrongly rejected e.g. `const int *` => `const void *`
				// and wrongly accepted `const int *` => `void * const`.
				if ( assignResult > 0 && tq1 <= tq2 ) {
					cost = Cost::safe;
				} else if ( assignResult < 0 ) {
					cost = Cost::unsafe;
				} // if
				// assignResult == 0 means Cost::Infinity
			} // if
		} else if ( dynamic_cast< ZeroType * >( dest ) ) {
			cost = Cost::unsafe;
		} // if
	}

	void ConversionCost::visit( ArrayType * ) {
		// intentionally empty: no conversion is priced from an array type, so `cost` keeps
		// its current value (Cost::infinity for a freshly constructed visitor)
	}

	// Reference-to-rvalue conversion (cv T1 & => T2): price T1 => T2 by visiting the referent
	// with this same visitor, then adjust by how the qualifiers change.  The cv on the
	// referent can be dropped because of implicit dereference, but keeping them exact is preferred.
	void ConversionCost::visit( ReferenceType * refType ) {
		// Note: dest can never be a reference, since it would have been caught in an earlier check
		assert( ! dynamic_cast< ReferenceType * >( dest ) );
		Type * referent = refType->base;
		referent->accept( *this );

		Type::Qualifiers fromQuals = referent->get_qualifiers();
		Type::Qualifiers toQuals = dest->get_qualifiers();
		if ( fromQuals == toQuals ) {
			cost.incReference();  // prefer exact qualifiers
		} else if ( fromQuals < toQuals ) {
			cost.incSafe();       // then gaining qualifiers
		} else {
			cost.incUnsafe();     // lose qualifiers as last resort
		}
		PRINT( std::cerr << refType << " ==> " << dest << " " << cost << std::endl; )
	}

	void ConversionCost::visit( FunctionType * ) {
		// intentionally empty: no conversion is priced from a function type, so `cost`
		// keeps its current value (Cost::infinity for a freshly constructed visitor)
	}

	// A struct instance converts for free only to a struct instance with the same name.
	void ConversionCost::visit( StructInstType * inst ) {
		StructInstType *destAsInst = dynamic_cast< StructInstType* >( dest );
		if ( destAsInst != nullptr && destAsInst->name == inst->name ) {
			cost = Cost::zero;
		}
	}

	// A union instance converts for free only to a union instance with the same name.
	void ConversionCost::visit( UnionInstType * inst ) {
		UnionInstType *destAsInst = dynamic_cast< UnionInstType* >( dest );
		if ( destAsInst != nullptr && destAsInst->name == inst->name ) {
			cost = Cost::zero;
		}
	}

	// An enumeration is priced as a signed int; if that conversion is cheaper than unsafe,
	// one extra safe step is added for leaving the enum type.
	void ConversionCost::visit( EnumInstType * ) {
		static Type::Qualifiers noQualifiers;
		static BasicType signedInt( noQualifiers, BasicType::SignedInt );
		signedInt.accept( *this );  // safe if dest >= int
		if ( cost < Cost::unsafe ) cost.incSafe();
	}

	void ConversionCost::visit( TraitInstType * ) {
		// intentionally empty: no conversion is priced from a trait instance, so `cost`
		// keeps its current value (Cost::infinity for a freshly constructed visitor)
	}

	// Conversion from a type variable / named type instance.
	//  - bound in env: cost of converting from the bound type (infinity if the class is unbound)
	//  - dest is a type instance with the same name: zero
	//  - otherwise, a declared base type converts with one extra safe step
	void ConversionCost::visit( TypeInstType *inst ) {
		EqvClass eqvClass;
		NamedTypeDecl *namedType;
		if ( env.lookup( inst->get_name(), eqvClass ) ) {
			if ( eqvClass.type ) {
				cost = conversionCost( eqvClass.type, dest, indexer, env );
			} else {
				// The class has no bound type, so there is nothing to convert from.  Recursing
				// would pass a null src into conversionCost; treat it as impossible instead,
				// matching how conversionCost handles an unbound destination class.
				cost = Cost::infinity;
			}
		} else if ( TypeInstType *destAsInst = dynamic_cast< TypeInstType* >( dest ) ) {
			if ( inst->get_name() == destAsInst->get_name() ) {
				cost = Cost::zero;
			}
		} else if ( ( namedType = indexer.lookupType( inst->get_name() ) ) ) {
			TypeDecl *type = dynamic_cast< TypeDecl* >( namedType );
			// all typedefs should be gone by this point
			assert( type );
			if ( type->get_base() ) {
				cost = conversionCost( type->get_base(), dest, indexer, env ) + Cost::safe;
			} // if
		} // if
	}

	// Tuple-to-tuple conversion: sum the element-wise conversion costs over the common
	// prefix.  Any impossible element leaves `cost` untouched; a dest tuple with elements
	// left over after src runs out is impossible.
	void ConversionCost::visit( TupleType * tupleType ) {
		TupleType * destAsTuple = dynamic_cast< TupleType * >( dest );
		if ( destAsTuple == nullptr ) return;

		const std::list< Type * > & srcTypes = tupleType->get_types();
		const std::list< Type * > & destTypes = destAsTuple->get_types();
		Cost total = Cost::zero;
		auto srcIt = srcTypes.begin();
		auto destIt = destTypes.begin();
		for ( ; srcIt != srcTypes.end() && destIt != destTypes.end(); ++srcIt, ++destIt ) {
			Cost elementCost = conversionCost( *srcIt, *destIt, indexer, env );
			if ( elementCost == Cost::infinity ) return;
			total += elementCost;
		}
		if ( destIt == destTypes.end() ) {
			cost = total;
		} else {
			cost = Cost::infinity;
		}
	}

	// A varargs type matches only another varargs type, at no cost.
	void ConversionCost::visit( VarArgsType * ) {
		const bool destIsVarArgs = dynamic_cast< VarArgsType* >( dest ) != nullptr;
		if ( destIsVarArgs ) cost = Cost::zero;
	}

	// Literal 0 (zero_t): exact to zero_t, like a signed int (plus one safe step) to basic
	// types, and safe to any pointer.
	void ConversionCost::visit( ZeroType * ) {
		if ( dynamic_cast< ZeroType * >( dest ) ) {
			cost = Cost::zero;
			return;
		}
		if ( BasicType *destAsBasic = dynamic_cast< BasicType* >( dest ) ) {
			// same lookup as visit(BasicType*) for signed int, but +1 for safe conversions
			const int safeSteps = costMatrix[ BasicType::SignedInt ][ destAsBasic->get_kind() ];
			if ( safeSteps == -1 ) {
				cost = Cost::unsafe;
				return;
			}
			cost = Cost::zero;
			cost.incSafe( safeSteps + 1 );
			return;
		}
		if ( dynamic_cast< PointerType* >( dest ) ) {
			cost = Cost::safe;
		}
	}

	// Literal 1 (one_t): exact to one_t, and like a signed int (plus one safe step) to
	// basic types.
	void ConversionCost::visit( OneType * ) {
		if ( dynamic_cast< OneType * >( dest ) ) {
			cost = Cost::zero;
			return;
		}
		BasicType *destAsBasic = dynamic_cast< BasicType* >( dest );
		if ( destAsBasic == nullptr ) return;

		// same lookup as visit(BasicType*) for signed int, but +1 for safe conversions
		const int safeSteps = costMatrix[ BasicType::SignedInt ][ destAsBasic->get_kind() ];
		if ( safeSteps == -1 ) {
			cost = Cost::unsafe;
		} else {
			cost = Cost::zero;
			cost.incSafe( safeSteps + 1 );
		}
	}
} // namespace ResolvExpr

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
