//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// TupleAssignment.cc -- recognizes assignments ("?=?") whose left operand is a tuple and resolves them component-wise
//
// Author           : Rodolfo G. Esteves
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Peter A. Buhr
// Last Modified On : Mon May 18 15:02:53 2015
// Update Count     : 2
//

#include "ResolvExpr/AlternativeFinder.h"
#include "ResolvExpr/Alternative.h"
#include "ResolvExpr/typeops.h"
#include "SynTree/Expression.h"
#include "TupleAssignment.h"
#include "Common/SemanticError.h"

#include <functional>
#include <algorithm>
#include <iterator>
#include <iostream>
#include <cassert>
#include <set>

namespace Tuples {
	// Creates a spotter that resolves the component assignments it generates in the
	// indexer/environment of `f`. No matcher exists yet and nothing has matched.
	TupleAssignSpotter::TupleAssignSpotter( ResolvExpr::AlternativeFinder *f = 0 )
		: currentFinder( f ),
		  matcher( 0 ),
		  hasMatched( false ) {
	}

	// True when `expr` is an address-of expression whose operand satisfies isTuple
	// (a tuple expression or a tuple variable).
	// TODO: also check for function returning tuple of reference types
	bool TupleAssignSpotter::pointsToTuple( Expression *expr ) {
		AddressExpr *addr = dynamic_cast< AddressExpr * >( expr );
		if ( addr == 0 ) return false;
		return isTuple( addr->get_arg() );
	}

	// A declaration is a tuple variable when its declared type is a TupleType.
	bool TupleAssignSpotter::isTupleVar( DeclarationWithType *decl ) {
		return dynamic_cast< TupleType * >( decl->get_type() ) != 0;
	}

	// True if `expr` is a tuple expression or a reference to a tuple variable.
	// Null yields false. `isRight` is currently unused.
	// NOTE: multiple-value-returning calls are NOT detected here (see isMVR).
	bool TupleAssignSpotter::isTuple( Expression *expr, bool isRight ) {
		if ( expr == 0 ) return false;
		if ( dynamic_cast< TupleExpr * >( expr ) != 0 ) return true;

		VariableExpr *var = dynamic_cast< VariableExpr * >( expr );
		return var != 0 && isTupleVar( var->get_var() );
	}

	bool TupleAssignSpotter::match() {
		assert ( matcher != 0 );

		std::list< Expression * > new_assigns;
		if ( ! matcher->match(new_assigns) )
			return false;

		if ( new_assigns.empty() ) return false;
		/*return */matcher->solve( new_assigns );
		if ( dynamic_cast<TupleAssignSpotter::MultipleAssignMatcher *>( matcher ) ) {
			// now resolve new assignments
			std::list< Expression * > solved_assigns;
			ResolvExpr::AltList solved_alts;
			assert( currentFinder != 0 );

			ResolvExpr::AltList current;
			for ( std::list< Expression * >::iterator i = new_assigns.begin(); i != new_assigns.end(); ++i ) {
				//try {
				ResolvExpr::AlternativeFinder finder( currentFinder->get_indexer(), currentFinder->get_environ() );
				finder.findWithAdjustment(*i);
				// prune expressions that don't coincide with
				ResolvExpr::AltList alts = finder.get_alternatives();
				assert( alts.size() == 1 );
				assert(alts.front().expr != 0 );
				current.push_back( finder.get_alternatives().front() );
				solved_assigns.push_back( alts.front().expr->clone() );
				//solved_assigns.back()->print(std::cerr);
				/*} catch( ... ) {
				  continue; // no reasonable alternative found
				  }*/
			}
			options.add_option( current );

			return true;
		} else { // mass assignment
			//if ( new_assigns.empty() ) return false;
			std::list< Expression * > solved_assigns;
			ResolvExpr::AltList solved_alts;
			assert( currentFinder != 0 );

			ResolvExpr::AltList current;
			if ( optMass.empty() ) {
				for ( std::list< Expression * >::size_type i = 0; i != new_assigns.size(); ++i )
					optMass.push_back( ResolvExpr::AltList() );
			}
			int cnt = 0;
			for ( std::list< Expression * >::iterator i = new_assigns.begin(); i != new_assigns.end(); ++i, cnt++ ) {

				ResolvExpr::AlternativeFinder finder( currentFinder->get_indexer(), currentFinder->get_environ() );
				finder.findWithAdjustment(*i);
				ResolvExpr::AltList alts = finder.get_alternatives();
				assert( alts.size() == 1 );
				assert(alts.front().expr != 0 );
				current.push_back( finder.get_alternatives().front() );
				optMass[cnt].push_back( finder.get_alternatives().front() );
				solved_assigns.push_back( alts.front().expr->clone() );
			}

			return true;
		}

		return false;
	}

	// Multiple-value-returning: true when `expr` has more than one entry in its results list.
	bool TupleAssignSpotter::isMVR( Expression *expr ) {
		return expr->get_results().size() > 1;
	}

	// Decides whether the call `expr` is a tuple assignment. If it is, the resolved result is
	// pushed as a SolvedTupleExpr alternative onto currentFinder's alternatives.
	// `possibilities` lists candidate (left, right) operand alternative pairs; each must
	// have exactly two entries. Returns the (persistent) hasMatched flag.
	bool TupleAssignSpotter::isTupleAssignment( UntypedExpr * expr, std::list<ResolvExpr::AltList> &possibilities ) {
		// Only a call whose function is the name "?=?" can be a tuple assignment.
		NameExpr *assgnop = dynamic_cast< NameExpr * >( expr->get_function() );
		if ( ! assgnop || assgnop->get_name() != std::string("?=?") ) return hasMatched;

		for ( std::list<ResolvExpr::AltList>::iterator ali = possibilities.begin(); ali != possibilities.end(); ++ali ) {
			assert( ali->size() == 2 );
			ResolvExpr::AltList::iterator opit = ali->begin();
			ResolvExpr::Alternative op1 = *opit, op2 = *(++opit);

			if ( ! pointsToTuple( op1.expr ) ) continue; // also handles tuple vars

			if ( isTuple( op2.expr, true ) ) {
				matcher = new MultipleAssignMatcher( op1.expr, op2.expr );
			} else if ( isMVR( op2.expr ) ) {
				// TODO: handle MVR. Skip this possibility: no matcher was built for it, so
				// match() would re-run a stale matcher from an earlier possibility (adding a
				// duplicate option) or assert on a null one.
				continue;
			} else {
				// mass assignment
				matcher = new MassAssignMatcher( op1.expr, op2.expr );
			}
			// NOTE(review): a matcher allocated for an earlier possibility is overwritten without being freed.
			if ( match() )
				hasMatched = true;
		}

		if ( hasMatched ) {
			if ( dynamic_cast< TupleAssignSpotter::MultipleAssignMatcher * >( matcher ) ) {
				// get_best throws SemanticError unless exactly one option is minimal in every component.
				std::list< ResolvExpr::AltList > best = options.get_best();
				if ( best.size() == 1 ) {
					std::list< Expression * > solved_assigns;
					for ( ResolvExpr::AltList::iterator i = best.front().begin(); i != best.front().end(); ++i )
						solved_assigns.push_back( i->expr );
					/* assigning cost zero? */
					currentFinder->get_alternatives().push_front( ResolvExpr::Alternative( new SolvedTupleExpr( solved_assigns ), currentFinder->get_environ(), ResolvExpr::Cost() ) );
				}
			} else {
				// Mass assignment: pick the cheapest alternative(s) for each component independently.
				assert( ! optMass.empty() );
				ResolvExpr::AltList winners;
				for ( std::vector< ResolvExpr::AltList >::iterator i = optMass.begin(); i != optMass.end(); ++i )
					findMinCostAlt( i->begin(), i->end(), std::back_inserter( winners ) );

				std::list< Expression * > solved_assigns;
				for ( ResolvExpr::AltList::iterator i = winners.begin(); i != winners.end(); ++i )
					solved_assigns.push_back( i->expr );
				currentFinder->get_alternatives().push_front( ResolvExpr::Alternative( new SolvedTupleExpr( solved_assigns ), currentFinder->get_environ(), ResolvExpr::Cost() ) );
			}
		}
		return hasMatched;
	}

	// Resets both sides. When `_lhs` is the address of a tuple expression, lhs receives that
	// tuple's component pointers (not clones). `_rhs` is not inspected here; rhs stays empty.
	void TupleAssignSpotter::Matcher::init( Expression *_lhs, Expression *_rhs ) {
		lhs.clear();
		rhs.clear();

		AddressExpr *addr = dynamic_cast< AddressExpr * >( _lhs );
		if ( addr == 0 ) return;
		TupleExpr *tuple = dynamic_cast< TupleExpr * >( addr->get_arg() );
		if ( tuple == 0 ) return;

		lhs.insert( lhs.end(), tuple->get_exprs().begin(), tuple->get_exprs().end() );
	}

	// Initializes lhs from the left operand via init; rhs starts empty.
	TupleAssignSpotter::Matcher::Matcher( Expression *_lhs, Expression *_rhs ) {
		init( _lhs, _rhs );
	}

	// lhs comes from init; rhs receives the components of `_rhs` when it is a TupleExpr.
	// NOTE(review): a tuple *variable* on the right leaves rhs empty, so match() then fails
	// unless lhs is also empty — confirm whether that is intended.
	TupleAssignSpotter::MultipleAssignMatcher::MultipleAssignMatcher( Expression *_lhs, Expression *_rhs ) {
		init( _lhs, _rhs );

		TupleExpr *tuple = dynamic_cast< TupleExpr * >( _rhs );
		if ( tuple == 0 ) return;
		rhs.insert( rhs.end(), tuple->get_exprs().begin(), tuple->get_exprs().end() );
	}

	// Builds the untyped call ?=?( &left', right' ), where left' and right' are clones,
	// so the result shares no nodes with its inputs. Throws (an int) if either operand is null.
	UntypedExpr *TupleAssignSpotter::Matcher::createAssgn( Expression *left, Expression *right ) {
		if ( ! left || ! right )
			throw 0; // xxx - diagnose the problem

		std::list< Expression * > args;
		args.push_back( new AddressExpr( left->clone() ) );
		args.push_back( right->clone() );
		return new UntypedExpr( new NameExpr( "?=?" ), args );
	}

	// Mass assignment [l1, ..., ln] = r expands to the n assignments li = r, appended to `out`.
	// Requires a non-empty lhs and exactly one rhs expression; otherwise returns false.
	bool TupleAssignSpotter::MassAssignMatcher::match( std::list< Expression * > &out ) {
		if ( lhs.empty() || rhs.size() != 1 ) return false;

		// Use createAssgn, as MultipleAssignMatcher::match does, so that each generated
		// assignment owns clones of its operands. Before, the original tuple components were
		// wrapped directly and the single rhs node was shared by all n assignments.
		for ( std::list< Expression * >::iterator l = lhs.begin(); l != lhs.end(); ++l )
			out.push_back( TupleAssignSpotter::Matcher::createAssgn( *l, rhs.front() ) );

		return true;
	}

	// No-op that always reports success. Resolution of the generated mass assignments
	// (and filling of optMass) happens in TupleAssignSpotter::match.
	bool TupleAssignSpotter::MassAssignMatcher::solve( std::list< Expression * > &assigns ) {
		return true;
	}

	// Pairs lhs[i] with rhs[i] into assignments (via createAssgn) appended to `out`.
	// Fails, producing nothing, when the two sides have different lengths.
	// TODO: need more complicated matching
	bool TupleAssignSpotter::MultipleAssignMatcher::match( std::list< Expression * > &out ) {
		if ( lhs.size() != rhs.size() )
			return false;

		zipWith( lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), back_inserter(out), TupleAssignSpotter::Matcher::createAssgn );
		return true;
	}

	// No-op that always reports success. Resolution of the generated component assignments
	// (and registration of the resulting option) happens in TupleAssignSpotter::match.
	bool TupleAssignSpotter::MultipleAssignMatcher::solve( std::list< Expression * > &assigns ) {
		return true;
	}

	// Records `opt` (one resolved alternative per tuple component) as a candidate option.
	// costMatrix is kept transposed: row k holds component k's cost for every option, in
	// insertion order, so the column index equals the option's index in `options`.
	void TupleAssignSpotter::Options::add_option( ResolvExpr::AltList &opt ) {
		options.push_back( opt );

		// Grow (never shrink) so an option wider than every previous one cannot index past the end.
		// NOTE(review): options of differing widths leave rows with different lengths, so column
		// indices no longer line up with option indices — confirm all options share one width.
		if ( costMatrix.size() < opt.size() )
			costMatrix.resize( opt.size() );

		std::vector< std::vector< ResolvExpr::Cost > >::size_type row = 0;
		for ( ResolvExpr::AltList::iterator i = opt.begin(); i != opt.end(); ++i, ++row )
			costMatrix[row].push_back( i->cost );
	}

	std::list< ResolvExpr::AltList > TupleAssignSpotter::Options::get_best() {
		using namespace std;
		using namespace ResolvExpr;
		list< ResolvExpr::AltList > ret;
		list< multiset<int> > solns;
		for ( vector< vector<Cost> >::iterator i = costMatrix.begin(); i != costMatrix.end(); ++i ) {
			list<int> current;
			findMinCost( i->begin(), i->end(), back_inserter(current) );
			solns.push_back( multiset<int>(current.begin(), current.end()) );
		}
		// need to combine
		multiset<int> result;
		lift_intersection( solns.begin(), solns.end(), inserter( result, result.begin() ) );
		if ( result.size() != 1 )
			throw SemanticError("Ambiguous tuple expression");
		ret.push_back(get_option( *(result.begin() )));
		return ret;
	}

	void TupleAssignSpotter::Options::print( std::ostream &ostr ) {
		using namespace std;

		for ( vector< vector < ResolvExpr::Cost > >::iterator i = costMatrix.begin(); i != costMatrix.end(); ++i ) {
			for ( vector < ResolvExpr::Cost >::iterator j = i->begin(); j != i->end(); ++j )
				ostr << *j << " " ;
			ostr << std::endl;
		} // for
		return;
	}

	// Projection helper: yields the cost carried by an alternative.
	ResolvExpr::Cost extract_cost( ResolvExpr::Alternative &alt ) {
		const ResolvExpr::Cost &cost = alt.cost;
		return cost;
	}

	template< typename InputIterator, typename OutputIterator >
	void TupleAssignSpotter::Options::findMinCost( InputIterator begin, InputIterator end, OutputIterator out ) {
		using namespace ResolvExpr;
		std::list<int> alternatives;

		// select the alternatives that have the minimum parameter cost
		Cost minCost = Cost::infinity;
		unsigned int index = 0;
		for ( InputIterator i = begin; i != end; ++i, index++ ) {
			if ( *i < minCost ) {
				minCost = *i;
				alternatives.clear();
				alternatives.push_back( index );
			} else if ( *i == minCost ) {
				alternatives.push_back( index );
			}
		}
		std::copy( alternatives.begin(), alternatives.end(), out );
	}

	// Writes to `out` the intersection of all the sorted int collections in [begin, end).
	// An empty range writes nothing; a single collection is copied through unchanged.
	template< class InputIterator, class OutputIterator >
	void TupleAssignSpotter::Options::lift_intersection( InputIterator begin, InputIterator end, OutputIterator out ) {
		if ( begin == end ) return;

		InputIterator next = begin;
		++next;
		if ( next == end ) {
			std::copy( begin->begin(), begin->end(), out );
			return;
		}

		// Start from the first collection and narrow it by each following one.
		std::multiset<int> acc( begin->begin(), begin->end() ); // InputIterator::value_type::value_type
		for ( ; next != end; ++next ) {
			std::multiset<int> narrowed;
			std::set_intersection( acc.begin(), acc.end(), next->begin(), next->end(), std::inserter( narrowed, narrowed.begin() ) );
			acc.swap( narrowed );
		}

		std::copy( acc.begin(), acc.end(), out );
	}

	// Returns a copy of the `index`-th recorded option (insertion order).
	// Throws (an int) when `index` is out of range.
	ResolvExpr::AltList TupleAssignSpotter::Options::get_option( std::list< ResolvExpr::AltList >::size_type index ) {
		if ( index >= options.size() ) throw 0; // XXX

		std::list< ResolvExpr::AltList >::iterator it = options.begin();
		std::advance( it, index );
		return *it;
	}
} // namespace Tuples

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
