//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// LabelFixer.cc -- 
//
// Author           : Rodolfo G. Esteves
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Rob Schluntz
// Last Modified On : Wed Jun 24 16:24:34 2015
// Update Count     : 141
//

#include <list>
#include <cassert>

#include "LabelFixer.h"
#include "MLEMutator.h"
#include "SynTree/Expression.h"
#include "SynTree/Statement.h"
#include "SynTree/Declaration.h"
#include "utility.h"

#include <iostream>

namespace ControlStruct {
	LabelFixer::Entry::Entry( Statement *to, Statement *from ) : definition ( to ) {
		if ( from != 0 ) {
			UsageLoc loc; loc.stmt = from;
			usage.push_back( loc );
		}
	}

	LabelFixer::Entry::Entry( Statement *to, Expression *from ) : definition ( to ) {
		if ( from != 0 ) {
			UsageLoc loc; loc.expr = from;
			usage.push_back( loc );
		}
	}


	bool LabelFixer::Entry::insideLoop() {
		return ( dynamic_cast< ForStmt * > ( definition ) ||
			dynamic_cast< WhileStmt * > ( definition )  );
	}

	void LabelFixer::Entry::UsageLoc::accept( Visitor & visitor ) {
		if ( dynamic_cast< Statement * >( stmt ) ) {
			stmt->accept( visitor );
		} else {
			expr->accept( visitor );
		}
	}

	LabelFixer::LabelFixer( LabelGenerator *gen ) : generator ( gen ) {
		if ( generator == 0 )
			generator = LabelGenerator::getGenerator();
	}

	void LabelFixer::visit( FunctionDecl *functionDecl ) {
		maybeAccept( functionDecl->get_statements(), *this );

		MLEMutator mlemut( resolveJumps(), generator );
		functionDecl->acceptMutator( mlemut );
	}

	// prune to at most one label definition for each statement
	void LabelFixer::visit( Statement *stmt ) {
		currentStatement = stmt;
		std::list< Label > &labels = stmt->get_labels();

		if ( ! labels.empty() ) {
			// only remember one label for each statement
			Label current = setLabelsDef( labels, stmt );
			labels.clear();
			labels.push_front( current );
		} // if
	}

	void LabelFixer::visit( BranchStmt *branchStmt ) {
		visit ( ( Statement * )branchStmt );

		// for labeled branches, add an entry to the label table
		Label target = branchStmt->get_target();
		if ( target != "" ) {
			setLabelsUsg( target, branchStmt );
		}
	}

	void LabelFixer::visit( UntypedExpr *untyped ) {
		if ( NameExpr * func = dynamic_cast< NameExpr * >( untyped->get_function() ) ) {
			if ( func->get_name() == "&&" ) {
				NameExpr * arg = dynamic_cast< NameExpr * >( untyped->get_args().front() );
				Label target = arg->get_name();
				assert( target != "" );
				setLabelsUsg( target, untyped );
			} else {
				Visitor::visit( untyped );
			}
		}
	}


	// sets the definition of the labelTable entry to be the provided 
	// statement for every label in the list parameter. Happens for every kind of statement
	Label LabelFixer::setLabelsDef( std::list< Label > &llabel, Statement *definition ) {
		assert( definition != 0 );
		assert( llabel.size() > 0 );

		Entry * e = new Entry( definition );

		for ( std::list< Label >::iterator i = llabel.begin(); i != llabel.end(); i++ ) {
			if ( labelTable.find( *i ) == labelTable.end() ) {
				// all labels on this statement need to use the same entry, so this should only be created once
				// undefined and unused until now, add an entry
				labelTable[ *i ] =  e;
			} else if ( labelTable[ *i ]->defined() ) {
				// defined twice, error
				throw SemanticError( "Duplicate definition of label: " + *i );
			}	else {
				// used previously, but undefined until now -> link with this entry
				Entry * oldEntry = labelTable[ *i ];
				e->add_uses( *oldEntry );
				labelTable[ *i ] = e;
			} // if
		} // for

		// produce one of the labels attached to this statement to be 
		// temporarily used as the canonical label
		return labelTable[ llabel.front() ]->get_label();
	}

	// Remember all uses of a label.
	template< typename UsageNode >
	void LabelFixer::setLabelsUsg( Label orgValue, UsageNode *use ) {
		assert( use != 0 );

		if ( labelTable.find( orgValue ) != labelTable.end() ) {
			// the label has been defined or used before
			labelTable[ orgValue ]->add_use( use );
		} else {
			labelTable[ orgValue ] = new Entry( 0, use );
		}
	}

	class LabelGetter : public Visitor {
		public:
		LabelGetter( Label &label ) : label( label ) {}

		virtual void visit( BranchStmt * branchStmt ) {
			label = branchStmt->get_target();
		}

		virtual void visit( UntypedExpr * untyped ) {
			NameExpr * name = dynamic_cast< NameExpr * >( untyped->get_function() );
			assert( name );
			assert( name->get_name() == "&&" );
			NameExpr * arg = dynamic_cast< NameExpr * >( untyped->get_args().front() );
			assert( arg );
			label = arg->get_name();
		}		

		private:
			Label &label;
	};

	class LabelSetter : public Visitor {
		public:
		LabelSetter( Label label ) : label( label ) {}

		virtual void visit( BranchStmt * branchStmt ) {
			branchStmt->set_target( label );
		}

		virtual void visit( UntypedExpr * untyped ) {
			NameExpr * name = dynamic_cast< NameExpr * >( untyped->get_function() );
			assert( name );
			assert( name->get_name() == "&&" );
			NameExpr * arg = dynamic_cast< NameExpr * >( untyped->get_args().front() );
			assert( arg );
			arg->set_name( label );
		}

	private:
		Label label;
	};

	// Ultimately builds a table that maps a label to its defining statement.
	// In the process, 
	std::map<Label, Statement * > *LabelFixer::resolveJumps() throw ( SemanticError ) {
		std::map< Statement *, Entry * > def_us;

		// combine the entries for all labels that target the same location
		for ( std::map< Label, Entry *>::iterator i = labelTable.begin(); i != labelTable.end(); ++i ) {
			Entry *e = i->second;

			if ( def_us.find ( e->get_definition() ) == def_us.end() ) {
				def_us[ e->get_definition() ] = e;
			} else if ( e->used() ) {
				def_us[ e->get_definition() ]->add_uses( *e );
			}
		}

		// create a unique label for each target location. 
		for ( std::map< Statement *, Entry * >::iterator i = def_us.begin(); i != def_us.end(); ++i ) {
			Statement *to = (*i).first;
			Entry * entry = (*i).second;
			std::list< Entry::UsageLoc > &from = entry->get_uses();

			// no label definition found
			if ( to == 0 ) {
				Label undef;
				LabelGetter getLabel( undef );
				from.back().accept( getLabel );
				// Label undef = getLabel( from.back()->get_target() );
				throw SemanticError ( "'" + undef + "' label not defined");
			} // if

			// generate a new label, and attach it to its defining statement as the only label on that statement
			Label finalLabel = generator->newLabel( to->get_labels().back() );
			entry->set_label( finalLabel );

			to->get_labels().clear();
			to->get_labels().push_back( finalLabel );

			// redirect each of the source branch statements to the new target label
			for ( std::list< Entry::UsageLoc >::iterator j = from.begin(); j != from.end(); ++j ) {
				LabelSetter setLabel( finalLabel );
				(*j).accept( setLabel );
				// setLabel( *j, finalLabel );

				// BranchStmt *jump = *j;
				// assert( jump != 0 );
				// jump->set_target( finalLabel );
			} // for
		} // for

		// create a table where each label maps to its defining statement
		std::map< Label, Statement * > *ret = new std::map< Label, Statement * >();
		for ( std::map< Statement *, Entry * >::iterator i = def_us.begin(); i != def_us.end(); ++i ) {
			(*ret)[ (*i).second->get_label() ] = (*i).first;
		}

		return ret;
	}
}  // namespace ControlStruct

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
