//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// ExceptVisitor.cc --
//
// Author           : Andrew Beach
// Created On       : Wed Jun 14 16:49:00 2017
// Last Modified By : Andrew Beach
// Last Modified On : Fri Jun 30 13:30:00 2017
// Update Count     : 1
//

#include "ExceptTranslate.h"
#include "Common/PassVisitor.h"
#include "SynTree/Statement.h"
#include "SynTree/Declaration.h"
#include "SynTree/Expression.h"
#include "SynTree/Type.h"
#include "SynTree/Attribute.h"

namespace ControlStruct {

	// This (large) section could probably be moved out of the class
	// and be static helpers instead.

	// Type(Qualifiers &, false, std::list<Attribute *> &)

	// void (*function)();
	static FunctionType try_func_t(Type::Qualifiers(), false);
	// void (*function)(int, exception);
	static FunctionType catch_func_t(Type::Qualifiers(), false);
	// int (*function)(exception);
	static FunctionType match_func_t(Type::Qualifiers(), false);
	// bool (*function)(exception);
	static FunctionType handle_func_t(Type::Qualifiers(), false);
	// void (*function)(__attribute__((unused)) void *);
	static FunctionType finally_func_t(Type::Qualifiers(), false);

	static void init_func_types() {
		static bool init_complete = false;
		if (init_complete) {
			return;
		}
		ObjectDecl index_obj(
			"__handler_index",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			/*bitfieldWidth*/ NULL,
			new BasicType( emptyQualifiers, BasicType::SignedInt ),
			/*init*/ NULL
			);
		ObjectDecl exception_obj(
			"__exception_inst",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			/*bitfieldWidth*/ NULL,
			new PointerType(
				emptyQualifiers,
				new BasicType( emptyQualifiers, BasicType::SignedInt )
				),
			/*init*/ NULL
			);
		ObjectDecl bool_obj(
			"__ret_bool",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			/*bitfieldWidth*/ NULL,
			new BasicType(emptyQualifiers, BasicType::Bool),
			/*init*/ NULL
			);
		ObjectDecl voidptr_obj(
			"__hook",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			NULL,
			new PointerType(
				emptyQualifiers,
				new VoidType(
					emptyQualifiers
					),
				std::list<Attribute *>{new Attribute("unused")}
				),
			NULL
			);

		catch_func_t.get_parameters().push_back( index_obj.clone() );
		catch_func_t.get_parameters().push_back( exception_obj.clone() );
		match_func_t.get_returnVals().push_back( index_obj.clone() );
		match_func_t.get_parameters().push_back( exception_obj.clone() );
		handle_func_t.get_returnVals().push_back( bool_obj.clone() );
		handle_func_t.get_parameters().push_back( exception_obj.clone() );
		finally_func_t.get_parameters().push_back( voidptr_obj.clone() );

		init_complete = true;
	}

	// Buricratic Helpers (Not having to do with the paritular operation.)

	typedef std::list<CatchStmt*> CatchList;

	void split( CatchList& allHandlers, CatchList& terHandlers,
				CatchList& resHandlers ) {
		while ( !allHandlers.empty() ) {
			CatchStmt * stmt = allHandlers.front();
			allHandlers.pop_front();
			if (CatchStmt::Terminate == stmt->get_kind()) {
				terHandlers.push_back(stmt);
			} else {
				resHandlers.push_back(stmt);
			}
		}
	}

	template<typename T>
	void free_all( std::list<T *> &list ) {
		typename std::list<T *>::iterator it;
		for ( it = list.begin() ; it != list.end() ; ++it ) {
			delete *it;
		}
		list.clear();
	}

	void appendDeclStmt( CompoundStmt * block, Declaration * item ) {
		block->push_back(new DeclStmt(noLabels, item));
	}

	Expression * nameOf( DeclarationWithType * decl ) {
		return new VariableExpr( decl );
	}

	// ThrowStmt Mutation Helpers

	Statement * create_given_throw(
			const char * throwFunc, ThrowStmt * throwStmt ) {
		// { int NAME = EXPR; throwFunc( &NAME ); }
		CompoundStmt * result = new CompoundStmt( noLabels );
		ObjectDecl * local = new ObjectDecl(
			"__local_exception_copy",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			NULL,
			new BasicType( emptyQualifiers, BasicType::SignedInt ),
			new SingleInit( throwStmt->get_expr() )
			);
		appendDeclStmt( result, local );
		UntypedExpr * call = new UntypedExpr( new NameExpr( throwFunc ) );
		call->get_args().push_back( new AddressExpr( nameOf( local ) ) );
		result->push_back( new ExprStmt( throwStmt->get_labels(), call ) );
		throwStmt->set_expr( nullptr );
		delete throwStmt;
		return result;
	}

	Statement * create_terminate_throw( ThrowStmt *throwStmt ) {
		// { int NAME = EXPR; __throw_terminate( &NAME ); }
		return create_given_throw( "__cfaehm__throw_termination", throwStmt );
	}
	Statement * create_terminate_rethrow( ThrowStmt *throwStmt ) {
		// __rethrow_terminate();
		assert( nullptr == throwStmt->get_expr() );
		Statement * result = new ExprStmt(
			throwStmt->get_labels(),
			new UntypedExpr( new NameExpr( "__cfaehm__rethrow_termination" ) )
			);
		delete throwStmt;
		return result;
	}
	Statement * create_resume_throw( ThrowStmt *throwStmt ) {
		// __throw_resume( EXPR );
		return create_given_throw( "__cfaehm__throw_resumption", throwStmt );
	}
	Statement * create_resume_rethrow( ThrowStmt *throwStmt ) {
		// return false;
		Statement * result = new ReturnStmt(
			throwStmt->get_labels(),
			new ConstantExpr( Constant::from_bool( false ) )
			);
		delete throwStmt;
		return result;
	}

	// TryStmt Mutation Helpers

	CompoundStmt * take_try_block( TryStmt *tryStmt ) {
		CompoundStmt * block = tryStmt->get_block();
		tryStmt->set_block( nullptr );
		return block;
	}
	FunctionDecl * create_try_wrapper( CompoundStmt *body ) {

		return new FunctionDecl( "try", Type::StorageClasses(),
			LinkageSpec::Cforall, try_func_t.clone(), body );
	}

	FunctionDecl * create_terminate_catch( CatchList &handlers ) {
		std::list<CaseStmt *> handler_wrappers;

		FunctionType *func_type = catch_func_t.clone();
		DeclarationWithType * index_obj = func_type->get_parameters().front();
	//	DeclarationWithType * except_obj = func_type->get_parameters().back();

		// Index 1..{number of handlers}
		int index = 0;
		CatchList::iterator it = handlers.begin();
		for ( ; it != handlers.end() ; ++it ) {
			++index;
			CatchStmt * handler = *it;

			// INTEGERconstant Version
			// case `index`:
			// {
			//     `handler.body`
			// }
			// return;
			std::list<Statement *> caseBody;
			caseBody.push_back( handler->get_body() );
			handler->set_body( nullptr );
			caseBody.push_back( new ReturnStmt( noLabels, nullptr ) );

			handler_wrappers.push_back( new CaseStmt(
				noLabels,
				new ConstantExpr( Constant::from_int( index ) ),
				caseBody
				) );
		}
		// TODO: Some sort of meaningful error on default perhaps?

		std::list<Statement*> stmt_handlers;
		while ( !handler_wrappers.empty() ) {
			stmt_handlers.push_back( handler_wrappers.front() );
			handler_wrappers.pop_front();
		}

		SwitchStmt * handler_lookup = new SwitchStmt(
			noLabels,
			nameOf( index_obj ),
			stmt_handlers
			);
		CompoundStmt * body = new CompoundStmt( noLabels );
		body->push_back( handler_lookup );

		return new FunctionDecl("catch", Type::StorageClasses(),
			LinkageSpec::Cforall, func_type, body);
	}

	// Create a single check from a moddified handler.
	// except_obj is referenced, modded_handler will be freed.
	CompoundStmt *create_single_matcher(
			DeclarationWithType * except_obj, CatchStmt * modded_handler ) {
		CompoundStmt * block = new CompoundStmt( noLabels );

		// INTEGERconstant Version
		assert( nullptr == modded_handler->get_decl() );
		ConstantExpr * number =
			dynamic_cast<ConstantExpr*>( modded_handler->get_cond() );
		assert( number );
		modded_handler->set_cond( nullptr );

		Expression * cond;
		{
			std::list<Expression *> args;
			args.push_back( number );

			std::list<Expression *> rhs_args;
			rhs_args.push_back( nameOf( except_obj ) );
			Expression * rhs = new UntypedExpr(
				new NameExpr( "*?" ), rhs_args );
			args.push_back( rhs );

			cond = new UntypedExpr( new NameExpr( "?==?" /*???*/), args );
		}

		if ( modded_handler->get_cond() ) {
			cond = new LogicalExpr( cond, modded_handler->get_cond() );
		}
		block->push_back( new IfStmt( noLabels,
			cond, modded_handler->get_body(), nullptr ) );

		modded_handler->set_decl( nullptr );
		modded_handler->set_cond( nullptr );
		modded_handler->set_body( nullptr );
		delete modded_handler;
		return block;
	}

	FunctionDecl * create_terminate_match( CatchList &handlers ) {
		CompoundStmt * body = new CompoundStmt( noLabels );

		FunctionType * func_type = match_func_t.clone();
		DeclarationWithType * except_obj = func_type->get_parameters().back();

		// Index 1..{number of handlers}
		int index = 0;
		CatchList::iterator it;
		for ( it = handlers.begin() ; it != handlers.end() ; ++it ) {
			++index;
			CatchStmt * handler = *it;

			// Body should have been taken by create_terminate_catch.
			assert( nullptr == handler->get_body() );

			// Create new body.
			handler->set_body( new ReturnStmt( noLabels,
				new ConstantExpr( Constant::from_int( index ) ) ) );

			// Create the handler.
			body->push_back( create_single_matcher( except_obj, handler ) );
			*it = nullptr;
		}

		body->push_back( new ReturnStmt( noLabels, new ConstantExpr(
			Constant::from_int( 0 ) ) ) );

		return new FunctionDecl("match", Type::StorageClasses(),
			LinkageSpec::Cforall, func_type, body);
	}

	CompoundStmt * create_terminate_caller(
			FunctionDecl * try_wrapper,
			FunctionDecl * terminate_catch,
			FunctionDecl * terminate_match) {

		UntypedExpr * caller = new UntypedExpr( new NameExpr(
			"__cfaehm__try_terminate" ) );
		std::list<Expression *>& args = caller->get_args();
		args.push_back( nameOf( try_wrapper ) );
		args.push_back( nameOf( terminate_catch ) );
		args.push_back( nameOf( terminate_match ) );

		CompoundStmt * callStmt = new CompoundStmt( noLabels );
		callStmt->push_back( new ExprStmt( noLabels, caller ) );
		return callStmt;
	}

	FunctionDecl * create_resume_handler( CatchList &handlers ) {
		CompoundStmt * body = new CompoundStmt( noLabels );

		FunctionType * func_type = match_func_t.clone();
		DeclarationWithType * except_obj = func_type->get_parameters().back();

		CatchList::iterator it;
		for ( it = handlers.begin() ; it != handlers.end() ; ++it ) {
			CatchStmt * handler = *it;

			// Modifiy body.
			CompoundStmt * handling_code =
				dynamic_cast<CompoundStmt*>( handler->get_body() );
			if ( ! handling_code ) {
				handling_code = new CompoundStmt( noLabels );
				handling_code->push_back( handler->get_body() );
			}
			handling_code->push_back( new ReturnStmt( noLabels,
				new ConstantExpr( Constant::from_bool( false ) ) ) );
			handler->set_body( handling_code );

			// Create the handler.
			body->push_back( create_single_matcher( except_obj, handler ) );
			*it = nullptr;
		}

		return new FunctionDecl("handle", Type::StorageClasses(),
			LinkageSpec::Cforall, func_type, body);
	}

	CompoundStmt * create_resume_wrapper(
			StructDecl * node_decl,
			Statement * wraps,
			FunctionDecl * resume_handler ) {
		CompoundStmt * body = new CompoundStmt( noLabels );

		// struct __try_resume_node __resume_node
		//  	__attribute__((cleanup( __cfaehm__try_resume_cleanup )));
		// ** unwinding of the stack here could cause problems **
		// ** however I don't think that can happen currently **
		// __cfaehm__try_resume_setup( &__resume_node, resume_handler );

		std::list< Attribute * > attributes;
		{
			std::list< Expression * > attr_params;
			attr_params.push_back( new NameExpr(
				"__cfaehm__try_resume_cleanup" ) );
			attributes.push_back( new Attribute( "cleanup", attr_params ) );
		}

		ObjectDecl * obj = new ObjectDecl(
			"__resume_node",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			nullptr,
			new StructInstType(
				Type::Qualifiers(),
				node_decl
				),
			nullptr,
			attributes
			);
		appendDeclStmt( body, obj );

		UntypedExpr *setup = new UntypedExpr( new NameExpr(
			"__cfaehm__try_resume_setup" ) );
		setup->get_args().push_back( new AddressExpr( nameOf( obj ) ) );
		setup->get_args().push_back( nameOf( resume_handler ) );

		body->push_back( new ExprStmt( noLabels, setup ) );

		body->push_back( wraps );
		return body;
	}

	FunctionDecl * create_finally_wrapper( TryStmt * tryStmt ) {
		FinallyStmt * finally = tryStmt->get_finally();
		CompoundStmt * body = finally->get_block();
		finally->set_block( nullptr );
		delete finally;
		tryStmt->set_finally( nullptr );

		return new FunctionDecl("finally", Type::StorageClasses(),
			LinkageSpec::Cforall, finally_func_t.clone(), body);
	}

	ObjectDecl * create_finally_hook(
			StructDecl * hook_decl, FunctionDecl * finally_wrapper ) {
		// struct __cfaehm__cleanup_hook __finally_hook
		//   	__attribute__((cleanup( finally_wrapper )));

		// Make Cleanup Attribute.
		std::list< Attribute * > attributes;
		{
			std::list< Expression * > attr_params;
			attr_params.push_back( nameOf( finally_wrapper ) );
			attributes.push_back( new Attribute( "cleanup", attr_params ) );
		}

		return new ObjectDecl(
			"__finally_hook",
			Type::StorageClasses(),
			LinkageSpec::Cforall,
			nullptr,
			new StructInstType(
				emptyQualifiers,
				hook_decl
				),
			nullptr,
			attributes
			);
	}


	class ExceptionMutatorCore : public WithGuards {
		enum Context { NoHandler, TerHandler, ResHandler };

		// Also need to handle goto, break & continue.
		// They need to be cut off in a ResHandler, until we enter another
		// loop, switch or the goto stays within the function.

		Context cur_context;

		// We might not need this, but a unique base for each try block's
		// generated functions might be nice.
		//std::string curFunctionName;
		//unsigned int try_count = 0;

		StructDecl *node_decl;
		StructDecl *hook_decl;

	public:
		ExceptionMutatorCore() :
			cur_context(NoHandler),
			node_decl(nullptr), hook_decl(nullptr)
		{}

		void premutate( CatchStmt *catchStmt );
		void premutate( StructDecl *structDecl );
		Statement * postmutate( ThrowStmt *throwStmt );
		Statement * postmutate( TryStmt *tryStmt );
	};

	Statement * ExceptionMutatorCore::postmutate( ThrowStmt *throwStmt ) {
		// Ignoring throwStmt->get_target() for now.
		if ( ThrowStmt::Terminate == throwStmt->get_kind() ) {
			if ( throwStmt->get_expr() ) {
				return create_terminate_throw( throwStmt );
			} else if ( TerHandler == cur_context ) {
				return create_terminate_rethrow( throwStmt );
			} else {
				assertf(false, "Invalid throw in %s at %i\n",
					throwStmt->location.filename.c_str(),
					throwStmt->location.linenumber);
				return nullptr;
			}
		} else {
			if ( throwStmt->get_expr() ) {
				return create_resume_throw( throwStmt );
			} else if ( ResHandler == cur_context ) {
				return create_resume_rethrow( throwStmt );
			} else {
				assertf(false, "Invalid throwResume in %s at %i\n",
					throwStmt->location.filename.c_str(),
					throwStmt->location.linenumber);
				return nullptr;
			}
		}
	}

	Statement * ExceptionMutatorCore::postmutate( TryStmt *tryStmt ) {
		assert( node_decl );
		assert( hook_decl );

		// Generate a prefix for the function names?

		CompoundStmt * block = new CompoundStmt( noLabels );
		CompoundStmt * inner = take_try_block( tryStmt );

		if ( tryStmt->get_finally() ) {
			// Define the helper function.
			FunctionDecl * finally_block =
				create_finally_wrapper( tryStmt );
			appendDeclStmt( block, finally_block );
			// Create and add the finally cleanup hook.
			appendDeclStmt( block,
				create_finally_hook( hook_decl, finally_block ) );
		}

		CatchList termination_handlers;
		CatchList resumption_handlers;
		split( tryStmt->get_catchers(),
			   termination_handlers, resumption_handlers );

		if ( resumption_handlers.size() ) {
			// Define the helper function.
			FunctionDecl * resume_handler =
				create_resume_handler( resumption_handlers );
			appendDeclStmt( block, resume_handler );
			// Prepare hooks
			inner = create_resume_wrapper( node_decl, inner, resume_handler );
		}

		if ( termination_handlers.size() ) {
			// Define the three helper functions.
			FunctionDecl * try_wrapper = create_try_wrapper( inner );
			appendDeclStmt( block, try_wrapper );
			FunctionDecl * terminate_catch =
				create_terminate_catch( termination_handlers );
			appendDeclStmt( block, terminate_catch );
			FunctionDecl * terminate_match =
				create_terminate_match( termination_handlers );
			appendDeclStmt( block, terminate_match );
			// Build the call to the try wrapper.
			inner = create_terminate_caller(
				try_wrapper, terminate_catch, terminate_match );
		}

		// Embed the try block.
		block->push_back( inner );

		//free_all( termination_handlers );
		//free_all( resumption_handlers );

		return block;
	}

	void ExceptionMutatorCore::premutate( CatchStmt *catchStmt ) {
		GuardValue( cur_context );
		if ( CatchStmt::Terminate == catchStmt->get_kind() ) {
			cur_context = TerHandler;
		} else {
			cur_context = ResHandler;
		}
	}

	void ExceptionMutatorCore::premutate( StructDecl *structDecl ) {
		if ( !structDecl->has_body() ) {
			// Skip children?
			return;
		} else if ( structDecl->get_name() == "__cfaehm__try_resume_node" ) {
			assert( nullptr == node_decl );
			node_decl = structDecl;
		} else if ( structDecl->get_name() == "__cfaehm__cleanup_hook" ) {
			assert( nullptr == hook_decl );
			hook_decl = structDecl;
		}
		// Later we might get the exception type as well.
	}

	void translateEHM( std::list< Declaration *> & translationUnit ) {
		init_func_types();

		PassVisitor<ExceptionMutatorCore> translator;
		for ( Declaration * decl : translationUnit ) {
			decl->acceptMutator( translator );
		}
	}
}
