//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// ExpandCasts.cc --
//
// Author           : Andrew Beach
// Created On       : Mon Jul 24 13:59:00 2017
// Last Modified By : Andrew Beach
// Last Modified On : Tue Jul 22 10:04:00 2020
// Update Count     : 3
//

#include "ExpandCasts.h"

#include <cassert>                 // for assert, assertf
#include <iterator>                // for back_inserter, inserter
#include <map>                     // for map, _Rb_tree_iterator, map<>::ite...
#include <string>                  // for string, allocator, operator==, ope...
#include <utility>                 // for pair

#include "Common/PassVisitor.h"    // for PassVisitor
#include "Common/SemanticError.h"  // for SemanticError
#include "SymTab/Mangler.h"        // for mangleType
#include "SynTree/Declaration.h"   // for ObjectDecl, StructDecl, FunctionDecl
#include "SynTree/Expression.h"    // for VirtualCastExpr, CastExpr, Address...
#include "SynTree/Mutator.h"       // for mutateAll
#include "SynTree/Type.h"          // for Type, PointerType, StructInstType
#include "SynTree/Visitor.h"       // for acceptAll

namespace Virtual {

	// Indented until the new ast code gets added.

	/// Maps virtual table types the instance for that type.
	class VirtualTableMap final {
		std::unordered_map<std::string, ObjectDecl *> vtable_instances;
	public:
		ObjectDecl * insert( ObjectDecl * vtableDecl ) {
			std::string const & mangledName = SymTab::Mangler::mangleType( vtableDecl->type );
			ObjectDecl *& value = vtable_instances[ mangledName ];
			if ( value ) {
				if ( vtableDecl->storageClasses.is_extern ) {
					return nullptr;
				} else if ( ! value->storageClasses.is_extern ) {
					return value;
				}
			}
			value = vtableDecl;
			return nullptr;
		}

		ObjectDecl * lookup( const Type * vtableType ) {
			std::string const & mangledName = SymTab::Mangler::mangleType( vtableType );
			const auto it = vtable_instances.find( mangledName );
			return ( vtable_instances.end() == it ) ? nullptr : it->second;
		}
	};

	/* Currently virtual depends on the rather brittle name matching between
	 * a (strict/explicate) virtual type, its vtable type and the vtable
	 * instance.
	 * A stronger implementation, would probably keep track of those triads
	 * and use that information to create better error messages.
	 */

	namespace {

	std::string get_vtable_name( std::string const & name ) {
		return name + "_vtable";
	}

	std::string get_vtable_inst_name( std::string const & name ) {
		return std::string("_") + get_vtable_name( name ) + "_instance";
	}

	std::string get_vtable_name_root( std::string const & name ) {
		return name.substr(0, name.size() - 7 );
	}

	std::string get_vtable_inst_name_root( std::string const & name ) {
		return get_vtable_name_root( name.substr(1, name.size() - 10 ) );
	}

	bool is_vtable_inst_name( std::string const & name ) {
		return 17 < name.size() &&
			name == get_vtable_inst_name( get_vtable_inst_name_root( name ) );
	}

	} // namespace

	class VirtualCastCore {
		VirtualTableMap vtable_instances;
		FunctionDecl *vcast_decl;
		StructDecl *pvt_decl;

		Type * pointer_to_pvt(int level_of_indirection) {
			Type * type = new StructInstType(
				Type::Qualifiers( Type::Const ), pvt_decl );
			for (int i = 0 ; i < level_of_indirection ; ++i) {
				type = new PointerType( noQualifiers, type );
			}
			return type;
		}

	public:
		VirtualCastCore() :
			vtable_instances(), vcast_decl( nullptr ), pvt_decl( nullptr )
		{}

		void premutate( FunctionDecl * functionDecl );
		void premutate( StructDecl * structDecl );
		void premutate( ObjectDecl * objectDecl );

		Expression * postmutate( VirtualCastExpr * castExpr );
	};

	void VirtualCastCore::premutate( FunctionDecl * functionDecl ) {
		if ( (! vcast_decl) &&
		     functionDecl->get_name() == "__cfa__virtual_cast" ) {
			vcast_decl = functionDecl;
		}
	}

	void VirtualCastCore::premutate( StructDecl * structDecl ) {
		if ( pvt_decl || ! structDecl->has_body() ) {
			return;
		} else if ( structDecl->get_name() == "__cfa__parent_vtable" ) {
			pvt_decl = structDecl;
		}
	}

	void VirtualCastCore::premutate( ObjectDecl * objectDecl ) {
		if ( is_vtable_inst_name( objectDecl->get_name() ) ) {
			if ( ObjectDecl * existing = vtable_instances.insert( objectDecl ) ) {
				std::string msg = "Repeated instance of virtual table, original found at: ";
				msg += existing->location.filename;
				msg += ":" + toString( existing->location.first_line );
				SemanticError( objectDecl->location, msg );
			}
		}
	}

	namespace {

	/// Better error locations for generated casts.
	CodeLocation castLocation( const VirtualCastExpr * castExpr ) {
		if ( castExpr->location.isSet() ) {
			return castExpr->location;
		} else if ( castExpr->arg->location.isSet() ) {
			return castExpr->arg->location;
		} else if ( castExpr->result->location.isSet() ) {
			return castExpr->result->location;
		} else {
			return CodeLocation();
		}
	}

	[[noreturn]] void castError( const VirtualCastExpr * castExpr, std::string const & message ) {
		SemanticError( castLocation( castExpr ), message );
	}

	/// Get the virtual table type used in a virtual cast.
	Type * getVirtualTableType( const VirtualCastExpr * castExpr ) {
		const Type * objectType;
		if ( auto target = dynamic_cast<const PointerType *>( castExpr->result ) ) {
			objectType = target->base;
		} else if ( auto target = dynamic_cast<const ReferenceType *>( castExpr->result ) ) {
			objectType = target->base;
		} else {
			castError( castExpr, "Virtual cast type must be a pointer or reference type." );
		}
		assert( objectType );

		const StructInstType * structType = dynamic_cast<const StructInstType *>( objectType );
		if ( nullptr == structType ) {
			castError( castExpr, "Virtual cast type must refer to a structure type." );
		}
		const StructDecl * structDecl = structType->baseStruct;
		assert( structDecl );

		const ObjectDecl * fieldDecl = nullptr;
		if ( 0 < structDecl->members.size() ) {
			const Declaration * memberDecl = structDecl->members.front();
			assert( memberDecl );
			fieldDecl = dynamic_cast<const ObjectDecl *>( memberDecl );
			if ( fieldDecl && fieldDecl->name != "virtual_table" ) {
				fieldDecl = nullptr;
			}
		}
		if ( nullptr == fieldDecl ) {
			castError( castExpr, "Virtual cast type must have a leading virtual_table field." );
		}
		const PointerType * fieldType = dynamic_cast<const PointerType *>( fieldDecl->type );
		if ( nullptr == fieldType ) {
			castError( castExpr, "Virtual cast type virtual_table field is not a pointer." );
		}
		assert( fieldType->base );
		auto virtualStructType = dynamic_cast<const StructInstType *>( fieldType->base );
		assert( virtualStructType );

		// Here is the type, but if it is polymorphic it will have lost information.
		// (Always a clone so that it may always be deleted.)
		StructInstType * virtualType = virtualStructType->clone();
		if ( ! structType->parameters.empty() ) {
			deleteAll( virtualType->parameters );
			virtualType->parameters.clear();
			cloneAll( structType->parameters, virtualType->parameters );
		}
		return virtualType;
	}

	} // namespace

	Expression * VirtualCastCore::postmutate( VirtualCastExpr * castExpr ) {
		assertf( castExpr->result, "Virtual Cast target not found before expansion." );

		assert( vcast_decl );
		assert( pvt_decl );

		const Type * vtable_type = getVirtualTableType( castExpr );
		ObjectDecl * table = vtable_instances.lookup( vtable_type );
		if ( nullptr == table ) {
			SemanticError( castLocation( castExpr ),
				"Could not find virtual table instance." );
		}

		Expression * result = new CastExpr(
			new ApplicationExpr( VariableExpr::functionPointer( vcast_decl ), {
					new CastExpr(
						new AddressExpr( new VariableExpr( table ) ),
						pointer_to_pvt(1)
					),
					new CastExpr(
						castExpr->get_arg(),
						pointer_to_pvt(2)
					)
			} ),
			castExpr->get_result()->clone()
		);

		castExpr->set_arg( nullptr );
		castExpr->set_result( nullptr );
		delete castExpr;
		delete vtable_type;
		return result;
	}

	void expandCasts( std::list< Declaration * > & translationUnit ) {
		PassVisitor<VirtualCastCore> translator;
		mutateAll( translationUnit, translator );
	}
}
