//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// CandidateFinder.cpp --
//
// Author           : Aaron B. Moss
// Created On       : Wed Jun 5 14:30:00 2019
// Last Modified By : Aaron B. Moss
// Last Modified On : Wed Jun 5 14:30:00 2019
// Update Count     : 1
//

#include "CandidateFinder.hpp"

#include <iterator>               // for back_inserter
#include <sstream>
#include <string>
#include <unordered_map>

#include "Candidate.hpp"
#include "CompilationState.h"
#include "Cost.h"
#include "Resolver.h"
#include "SatisfyAssertions.hpp"
#include "typeops.h"              // for adjustExprType
#include "Unify.h"
#include "AST/Expr.hpp"
#include "AST/Node.hpp"
#include "AST/Pass.hpp"
#include "AST/Print.hpp"
#include "AST/SymbolTable.hpp"
#include "SymTab/Mangler.h"

#define PRINT( text ) if ( resolvep ) { text }

namespace ResolvExpr {

namespace {

	/// Actually visits expressions to find their candidate interpretations
	struct Finder final : public ast::WithShortCircuiting {
		CandidateFinder & selfFinder;
		const ast::SymbolTable & symtab;
		CandidateList & candidates;
		const ast::TypeEnvironment & tenv;
		ast::ptr< ast::Type > & targetType;

		Finder( CandidateFinder & f )
		: selfFinder( f ), symtab( f.symtab ), candidates( f.candidates ), tenv( f.env ), 
		  targetType( f.targetType ) {}
		
		void previsit( const ast::Node * ) { visit_children = false; }

		/// Convenience to add candidate to list
		template<typename... Args>
		void addCandidate( Args &&... args ) {
			candidates.emplace_back( new Candidate{ std::forward<Args>( args )... } );
		}

		void postvisit( const ast::ApplicationExpr * applicationExpr ) {
			addCandidate( applicationExpr, tenv );
		}

		void postvisit( const ast::UntypedExpr * untypedExpr ) {
			#warning unimplemented
			(void)untypedExpr;
			assert(false);
		}

		/// true if expression is an lvalue
		static bool isLvalue( const ast::Expr * x ) {
			return x->result && ( x->result->is_lvalue() || x->result.as< ast::ReferenceType >() );
		}

		void postvisit( const ast::AddressExpr * addressExpr ) {
			CandidateFinder finder{ symtab, tenv };
			finder.find( addressExpr->arg );
			for ( CandidateRef & r : finder.candidates ) {
				if ( ! isLvalue( r->expr ) ) continue;
				addCandidate( *r, new ast::AddressExpr{ addressExpr->location, r->expr } );
			}
		}

		void postvisit( const ast::LabelAddressExpr * labelExpr ) {
			addCandidate( labelExpr, tenv );
		}

		void postvisit( const ast::CastExpr * castExpr ) {
			#warning unimplemented
			(void)castExpr;
			assert(false);
		}

		void postvisit( const ast::VirtualCastExpr * castExpr ) {
			assertf( castExpr->result, "Implicit virtual cast targets not yet supported." );
			CandidateFinder finder{ symtab, tenv };
			// don't prune here, all alternatives guaranteed to have same type
			finder.find( castExpr->arg, ResolvMode::withoutPrune() );
			for ( CandidateRef & r : finder.candidates ) {
				addCandidate( 
					*r, new ast::VirtualCastExpr{ castExpr->location, r->expr, castExpr->result } );
			}
		}

		void postvisit( const ast::UntypedMemberExpr * memberExpr ) {
			#warning unimplemented
			(void)memberExpr;
			assert(false);
		}

		void postvisit( const ast::MemberExpr * memberExpr ) {
			addCandidate( memberExpr, tenv );
		}

		void postvisit( const ast::NameExpr * variableExpr ) {
			#warning unimplemented
			(void)variableExpr;
			assert(false);
		}

		void postvisit( const ast::VariableExpr * variableExpr ) {
			// not sufficient to just pass `variableExpr` here, type might have changed since
			// creation
			addCandidate( 
				new ast::VariableExpr{ variableExpr->location, variableExpr->var }, tenv );
		}

		void postvisit( const ast::ConstantExpr * constantExpr ) {
			addCandidate( constantExpr, tenv );
		}

		void postvisit( const ast::SizeofExpr * sizeofExpr ) {
			#warning unimplemented
			(void)sizeofExpr;
			assert(false);
		}

		void postvisit( const ast::AlignofExpr * alignofExpr ) {
			#warning unimplemented
			(void)alignofExpr;
			assert(false);
		}

		void postvisit( const ast::UntypedOffsetofExpr * offsetofExpr ) {
			#warning unimplemented
			(void)offsetofExpr;
			assert(false);
		}

		void postvisit( const ast::OffsetofExpr * offsetofExpr ) {
			addCandidate( offsetofExpr, tenv );
		}

		void postvisit( const ast::OffsetPackExpr * offsetPackExpr ) {
			addCandidate( offsetPackExpr, tenv );
		}

		void postvisit( const ast::LogicalExpr * logicalExpr ) {
			CandidateFinder finder1{ symtab, tenv };
			finder1.find( logicalExpr->arg1, ResolvMode::withAdjustment() );
			if ( finder1.candidates.empty() ) return;

			CandidateFinder finder2{ symtab, tenv };
			finder2.find( logicalExpr->arg2, ResolvMode::withAdjustment() );
			if ( finder2.candidates.empty() ) return;

			for ( const CandidateRef & r1 : finder1.candidates ) {
				for ( const CandidateRef & r2 : finder2.candidates ) {
					ast::TypeEnvironment env{ r1->env };
					env.simpleCombine( r2->env );
					ast::OpenVarSet open{ r1->open };
					mergeOpenVars( open, r2->open );
					ast::AssertionSet need;
					mergeAssertionSet( need, r1->need );
					mergeAssertionSet( need, r2->need );

					addCandidate(
						new ast::LogicalExpr{ 
							logicalExpr->location, r1->expr, r2->expr, logicalExpr->isAnd },
						std::move( env ), std::move( open ), std::move( need ), 
						r1->cost + r2->cost );
				}
			}
		}

		void postvisit( const ast::ConditionalExpr * conditionalExpr ) {
			// candidates for condition
			CandidateFinder finder1{ symtab, tenv };
			finder1.find( conditionalExpr->arg1, ResolvMode::withAdjustment() );
			if ( finder1.candidates.empty() ) return;

			// candidates for true result
			CandidateFinder finder2{ symtab, tenv };
			finder2.find( conditionalExpr->arg2, ResolvMode::withAdjustment() );
			if ( finder2.candidates.empty() ) return;

			// candidates for false result
			CandidateFinder finder3{ symtab, tenv };
			finder3.find( conditionalExpr->arg3, ResolvMode::withAdjustment() );
			if ( finder3.candidates.empty() ) return;

			for ( const CandidateRef & r1 : finder1.candidates ) {
				for ( const CandidateRef & r2 : finder2.candidates ) {
					for ( const CandidateRef & r3 : finder3.candidates ) {
						ast::TypeEnvironment env{ r1->env };
						env.simpleCombine( r2->env );
						env.simpleCombine( r3->env );
						ast::OpenVarSet open{ r1->open };
						mergeOpenVars( open, r2->open );
						mergeOpenVars( open, r3->open );
						ast::AssertionSet need;
						mergeAssertionSet( need, r1->need );
						mergeAssertionSet( need, r2->need );
						mergeAssertionSet( need, r3->need );
						ast::AssertionSet have;

						// unify true and false results, then infer parameters to produce new 
						// candidates
						ast::ptr< ast::Type > common;
						if ( 
							unify( 
								r2->expr->result, r3->expr->result, env, need, have, open, symtab, 
								common ) 
						) {
							#warning unimplemented
							assert(false);
						}
					}
				}
			}
		}

		void postvisit( const ast::CommaExpr * commaExpr ) {
			ast::TypeEnvironment env{ tenv };
			ast::ptr< ast::Expr > arg1 = resolveInVoidContext( commaExpr->arg1, symtab, env );
			
			CandidateFinder finder2{ symtab, env };
			finder2.find( commaExpr->arg2, ResolvMode::withAdjustment() );

			for ( const CandidateRef & r2 : finder2.candidates ) {
				addCandidate( *r2, new ast::CommaExpr{ commaExpr->location, arg1, r2->expr } );
			}
		}

		void postvisit( const ast::ImplicitCopyCtorExpr * ctorExpr ) {
			addCandidate( ctorExpr, tenv );
		}

		void postvisit( const ast::ConstructorExpr * ctorExpr ) {
			CandidateFinder finder{ symtab, tenv };
			finder.find( ctorExpr->callExpr, ResolvMode::withoutPrune() );
			for ( CandidateRef & r : finder.candidates ) {
				addCandidate( *r, new ast::ConstructorExpr{ ctorExpr->location, r->expr } );
			}
		}

		void postvisit( const ast::RangeExpr * rangeExpr ) {
			// resolve low and high, accept candidates where low and high types unify
			CandidateFinder finder1{ symtab, tenv };
			finder1.find( rangeExpr->low, ResolvMode::withAdjustment() );
			if ( finder1.candidates.empty() ) return;

			CandidateFinder finder2{ symtab, tenv };
			finder2.find( rangeExpr->high, ResolvMode::withAdjustment() );
			if ( finder2.candidates.empty() ) return;

			for ( const CandidateRef & r1 : finder1.candidates ) {
				for ( const CandidateRef & r2 : finder2.candidates ) {
					ast::TypeEnvironment env{ r1->env };
					env.simpleCombine( r2->env );
					ast::OpenVarSet open{ r1->open };
					mergeOpenVars( open, r2->open );
					ast::AssertionSet need;
					mergeAssertionSet( need, r1->need );
					mergeAssertionSet( need, r2->need );
					ast::AssertionSet have;

					ast::ptr< ast::Type > common;
					if ( 
						unify( 
							r1->expr->result, r2->expr->result, env, need, have, open, symtab, 
							common ) 
					) {
						ast::RangeExpr * newExpr = 
							new ast::RangeExpr{ rangeExpr->location, r1->expr, r2->expr };
						newExpr->result = common ? common : r1->expr->result;
						
						#warning unimplemented
						assert(false);
					}
				}
			}
		}

		void postvisit( const ast::UntypedTupleExpr * tupleExpr ) {
			std::vector< CandidateFinder > subCandidates = 
				selfFinder.findSubExprs( tupleExpr->exprs );
			std::vector< CandidateList > possibilities;
			combos( subCandidates.begin(), subCandidates.end(), back_inserter( possibilities ) );

			for ( const CandidateList & subs : possibilities ) {
				std::vector< ast::ptr< ast::Expr > > exprs;
				exprs.reserve( subs.size() );
				for ( const CandidateRef & sub : subs ) { exprs.emplace_back( sub->expr ); }

				ast::TypeEnvironment env;
				ast::OpenVarSet open;
				ast::AssertionSet need;
				for ( const CandidateRef & sub : subs ) {
					env.simpleCombine( sub->env );
					mergeOpenVars( open, sub->open );
					mergeAssertionSet( need, sub->need );
				}

				addCandidate(
					new ast::TupleExpr{ tupleExpr->location, std::move( exprs ) }, 
					std::move( env ), std::move( open ), std::move( need ), sumCost( subs ) );
			}
		}

		void postvisit( const ast::TupleExpr * tupleExpr ) {
			addCandidate( tupleExpr, tenv );
		}

		void postvisit( const ast::TupleIndexExpr * tupleExpr ) {
			addCandidate( tupleExpr, tenv );
		}

		void postvisit( const ast::TupleAssignExpr * tupleExpr ) {
			addCandidate( tupleExpr, tenv );
		}

		void postvisit( const ast::UniqueExpr * unqExpr ) {
			CandidateFinder finder{ symtab, tenv };
			finder.find( unqExpr->expr, ResolvMode::withAdjustment() );
			for ( CandidateRef & r : finder.candidates ) {
				// ensure that the the id is passed on so that the expressions are "linked"
				addCandidate( *r, new ast::UniqueExpr{ unqExpr->location, r->expr, unqExpr->id } );
			}
		}

		void postvisit( const ast::StmtExpr * stmtExpr ) {
			#warning unimplemented
			(void)stmtExpr;
			assert(false);
		}

		void postvisit( const ast::UntypedInitExpr * initExpr ) {
			#warning unimplemented
			(void)initExpr;
			assert(false);
		}

		void postvisit( const ast::InitExpr * ) {
			assertf( false, "CandidateFinder should never see a resolved InitExpr." );
		}

		void postvisit( const ast::DeletedExpr * ) {
			assertf( false, "CandidateFinder should never see a DeletedExpr." );
		}

		void postvisit( const ast::GenericExpr * ) {
			assertf( false, "_Generic is not yet supported." );
		}
	};

	/// Prunes a list of candidates down to those that have the minimum conversion cost for a given 
	/// return type. Skips ambiguous candidates.
	CandidateList pruneCandidates( CandidateList & candidates ) {
		struct PruneStruct {
			CandidateRef candidate;
			bool ambiguous;

			PruneStruct() = default;
			PruneStruct( const CandidateRef & c ) : candidate( c ), ambiguous( false ) {}
		};

		// find lowest-cost candidate for each type
		std::unordered_map< std::string, PruneStruct > selected;
		for ( CandidateRef & candidate : candidates ) {
			std::string mangleName;
			{
				ast::ptr< ast::Type > newType = candidate->expr->result;
				candidate->env.apply( newType );
				mangleName = Mangle::mangle( newType );
			}

			auto found = selected.find( mangleName );
			if ( found != selected.end() ) {
				if ( candidate->cost < found->second.candidate->cost ) {
					PRINT(
						std::cerr << "cost " << candidate->cost << " beats " 
							<< found->second.candidate->cost << std::endl;
					)

					found->second = PruneStruct{ candidate };
				} else if ( candidate->cost == found->second.candidate->cost ) {
					// if one of the candidates contains a deleted identifier, can pick the other, 
					// since deleted expressions should not be ambiguous if there is another option 
					// that is at least as good
					if ( findDeletedExpr( candidate->expr ) ) {
						// do nothing
						PRINT( std::cerr << "candidate is deleted" << std::endl; )
					} else if ( findDeletedExpr( found->second.candidate->expr ) ) {
						PRINT( std::cerr << "current is deleted" << std::endl; )
						found->second = PruneStruct{ candidate };
					} else {
						PRINT( std::cerr << "marking ambiguous" << std::endl; )
						found->second.ambiguous = true;
					}
				} else {
					PRINT(
						std::cerr << "cost " << candidate->cost << " loses to " 
							<< found->second.candidate->cost << std::endl;
					)
				}
			} else {
				selected.emplace_hint( found, mangleName, candidate );
			}
		}

		// report unambiguous min-cost candidates
		CandidateList out;
		for ( auto & target : selected ) {
			if ( target.second.ambiguous ) continue;

			CandidateRef cand = target.second.candidate;
			
			ast::ptr< ast::Type > newResult = cand->expr->result;
			cand->env.applyFree( newResult );
			cand->expr = ast::mutate_field(
				cand->expr.get(), &ast::Expr::result, std::move( newResult ) );
			
			out.emplace_back( cand );
		}
		return out;
	}

	/// Returns a list of alternatives with the minimum cost in the given list
	CandidateList findMinCost( const CandidateList & candidates ) {
		CandidateList out;
		Cost minCost = Cost::infinity;
		for ( const CandidateRef & r : candidates ) {
			if ( r->cost < minCost ) {
				minCost = r->cost;
				out.clear();
				out.emplace_back( r );
			} else if ( r->cost == minCost ) {
				out.emplace_back( r );
			}
		}
		return out;
	}

} // anonymous namespace

void CandidateFinder::find( const ast::Expr * expr, ResolvMode mode ) {
	// Find alternatives for expression
	ast::Pass<Finder> finder{ *this };
	expr->accept( finder );

	if ( mode.failFast && candidates.empty() ) {
		SemanticError( expr, "No reasonable alternatives for expression " );
	}

	if ( mode.satisfyAssns || mode.prune ) {
		// trim candidates to just those where the assertions are satisfiable
		// - necessary pre-requisite to pruning
		CandidateList satisfied;
		std::vector< std::string > errors;
		for ( auto & candidate : candidates ) {
			satisfyAssertions( *candidate, symtab, satisfied, errors );
		}

		// fail early if none such
		if ( mode.failFast && satisfied.empty() ) {
			std::ostringstream stream;
			stream << "No alternatives with satisfiable assertions for " << expr << "\n";
			for ( const auto& err : errors ) {
				stream << err;
			}
			SemanticError( expr->location, stream.str() );
		}

		// reset candidates
		candidates = std::move( satisfied );
	}

	if ( mode.prune ) {
		// trim candidates to single best one
		PRINT(
			std::cerr << "alternatives before prune:" << std::endl;
			print( std::cerr, candidates );
		)

		CandidateList pruned = pruneCandidates( candidates );
		
		if ( mode.failFast && pruned.empty() ) {
			std::ostringstream stream;
			CandidateList winners = findMinCost( candidates );
			stream << "Cannot choose between " << winners.size() << " alternatives for "
				"expression\n";
			ast::print( stream, expr );
			stream << " Alternatives are:\n";
			print( stream, winners, 1 );
			SemanticError( expr->location, stream.str() );
		}

		auto oldsize = candidates.size();
		candidates = std::move( pruned );

		PRINT(
			std::cerr << "there are " << oldsize << " alternatives before elimination" << std::endl;
		)
		PRINT(
			std::cerr << "there are " << candidates.size() << " alternatives after elimination" 
				<< std::endl;
		)
	}

	// adjust types after pruning so that types substituted by pruneAlternatives are correctly 
	// adjusted
	if ( mode.adjust ) {
		for ( CandidateRef & r : candidates ) {
			r->expr = ast::mutate_field( 
				r->expr.get(), &ast::Expr::result, 
				adjustExprType( r->expr->result, r->env, symtab ) );
		}
	}

	// Central location to handle gcc extension keyword, etc. for all expressions
	for ( CandidateRef & r : candidates ) {
		if ( r->expr->extension != expr->extension ) {
			r->expr.get_and_mutate()->extension = expr->extension;
		}
	}
}

std::vector< CandidateFinder > CandidateFinder::findSubExprs( 
	const std::vector< ast::ptr< ast::Expr > > & xs 
) {
	std::vector< CandidateFinder > out;

	for ( const auto & x : xs ) {
		out.emplace_back( symtab, env );
		out.back().find( x, ResolvMode::withAdjustment() );
		
		PRINT(
			std::cerr << "findSubExprs" << std::endl;
			print( std::cerr, out.back().candidates );
		)
	}

	return out;
}

} // namespace ResolvExpr

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
