//                              -*- Mode: CPP -*-
//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Keywords.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Mar 13 12:41:22 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 3
//

#include "Concurrency/Keywords.h"

#include <list>
#include <string>
#include <utility>

#include "InitTweak/InitTweak.h"
#include "SymTab/AddVisit.h"
#include "SynTree/Declaration.h"
#include "SynTree/Expression.h"
#include "SynTree/Initializer.h"
#include "SynTree/Statement.h"
#include "SynTree/Type.h"
#include "SynTree/Visitor.h"

namespace Concurrency {

	namespace {
		// Shared "empty" constructor arguments (labels, attributes, storage classes,
		// qualifiers) for the SynTree nodes built by the passes in this file.
		const std::list< Label > noLabels{};
		const std::list< Attribute * > noAttributes{};
		Type::StorageClasses noStorage{};
		Type::Qualifiers noQualifiers{};
	}

	//=============================================================================================
	// Visitor declarations
	//=============================================================================================

	//-----------------------------------------------------------------------------
	//Handles struct type declarations tagged with a concurrency keyword
	//(keyword is one of thread, coroutine or monitor, see the subclasses below) :
	// keyword MyType {                         struct MyType {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             NewField_t newField;
	// };                                        };
	//                                           static inline NewField_t * getter_name( MyType * this ) { return &this->newField; }
	//
	// forwardDeclare additionally queues, in declsToAdd, a bodyless "struct MyType;", a
	// "main( MyType * this )" prototype (only when needs_main is set) and a prototype of
	// the getter. NewField_t is the struct named type_name; it must have been visited
	// before the first target struct, otherwise handle() throws context_error.
	class ConcurrentSueKeyword : public Visitor {
	  protected:
		// acceptAndAdd presumably needs the private declsToAdd / declsToAddAfter lists.
		template< typename Visitor >
		friend void SymTab::acceptAndAdd( std::list< Declaration * > &translationUnit, Visitor &visitor );
	  public:

		// The string parameters are rvalue references: move them into the members
		// (a named rvalue reference is an lvalue, so plain initialization would copy).
		ConcurrentSueKeyword( std::string&& type_name, std::string&& field_name, std::string&& getter_name, std::string&& context_error, bool needs_main ) :
		  type_name( std::move( type_name ) ),
		  field_name( std::move( field_name ) ),
		  getter_name( std::move( getter_name ) ),
		  context_error( std::move( context_error ) ),
		  needs_main( needs_main ) {}

		virtual ~ConcurrentSueKeyword() {}

		using Visitor::visit;
		// Records the type_name struct, or rewrites a struct selected by is_target.
		virtual void visit( StructDecl * decl ) override final;

		// Rewrites one target struct (no-op for forward declarations).
		void handle( StructDecl * );
		// Queues the forward declarations and returns the getter prototype.
		FunctionDecl * forwardDeclare( StructDecl * );
		// Appends the descriptor field to the struct and returns it.
		ObjectDecl * addField( StructDecl * );
		// Queues the getter definition after the struct.
		void addRoutines( StructDecl *, ObjectDecl *, FunctionDecl * );

		// Selects the structs this keyword applies to.
		virtual bool is_target( StructDecl * decl ) = 0;

	  private:
		const std::string type_name;      // descriptor struct name, e.g. "thread_desc"
		const std::string field_name;     // name of the field injected into each target struct
		const std::string getter_name;    // name of the generated accessor routine
		const std::string context_error;  // message thrown when type_name has not been seen
		const bool needs_main;            // whether a main( MyType * this ) prototype is emitted

		// Declarations queued before / after the declaration being visited.
		std::list< Declaration * > declsToAdd, declsToAddAfter;
		StructDecl* type_decl = nullptr;  // the type_name struct, once visited
	};


	//-----------------------------------------------------------------------------
	//Handles thread type declarations :
	// thread MyThread {                         struct MyThread {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             thread_desc __thrd;
	// };                                        };
	//                                           static inline thread_desc * get_thread( MyThread * this ) { return &this->__thrd; }
	//
	// needs_main is true, so a "main( MyThread * this )" prototype is also emitted ahead
	// of the struct.
	class ThreadKeyword final : public ConcurrentSueKeyword {
	  public:

		ThreadKeyword() : ConcurrentSueKeyword(
			"thread_desc",
			"__thrd",
			"get_thread",
			"thread keyword requires threads to be in scope, add #include <thread>",
			true
		)
		{}

		virtual ~ThreadKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_thread(); }

		// Runs the pass over a whole translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			ThreadKeyword impl;
			SymTab::acceptAndAdd( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles coroutine type declarations :
	// coroutine MyCoroutine {                   struct MyCoroutine {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             coroutine_desc __cor;
	// };                                        };
	//                                           static inline coroutine_desc * get_coroutine( MyCoroutine * this ) { return &this->__cor; }
	//
	// needs_main is true, so a "main( MyCoroutine * this )" prototype is also emitted ahead
	// of the struct.
	class CoroutineKeyword final : public ConcurrentSueKeyword {
	  public:

		CoroutineKeyword() : ConcurrentSueKeyword(
			"coroutine_desc",
			"__cor",
			"get_coroutine",
			"coroutine keyword requires coroutines to be in scope, add #include <coroutine>",
			true
		)
		{}

		virtual ~CoroutineKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_coroutine(); }

		// Runs the pass over a whole translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			CoroutineKeyword impl;
			SymTab::acceptAndAdd( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles monitor type declarations :
	// monitor MyMonitor {                       struct MyMonitor {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             monitor_desc __mon;
	// };                                        };
	//                                           static inline monitor_desc * get_monitor( MyMonitor * this ) { return &this->__mon; }
	//
	// needs_main is false: monitors get no main prototype.
	class MonitorKeyword final : public ConcurrentSueKeyword {
	  public:

		MonitorKeyword() : ConcurrentSueKeyword(
			"monitor_desc",
			"__mon",
			"get_monitor",
			"monitor keyword requires monitors to be in scope, add #include <monitor>",
			false
		)
		{}

		virtual ~MonitorKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_monitor(); }

		// Runs the pass over a whole translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			MonitorKeyword impl;
			SymTab::acceptAndAdd( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles mutex routines definitions :
	// void foo( A * mutex a, B * mutex b,  int i ) {                  void foo( A * a, B * b,  int i ) {
	// 	                                                                 monitor_desc * __monitors[] = { get_monitor(a), get_monitor(b) };
	// 	                                                                 monitor_guard_t __guard = { __monitors, 2 };
	//    /*Some code*/                                       =>           /*Some code*/
	// }                                                               }
	//
	class MutexKeyword final : public Visitor {
	  public:

		using Visitor::visit;
		virtual void visit( FunctionDecl * decl ) override final;
		virtual void visit(   StructDecl * decl ) override final;

		std::list<DeclarationWithType*> findMutexArgs( FunctionDecl* );
		void validate( DeclarationWithType * );
		void addStatments( CompoundStmt *, const std::list<DeclarationWithType * > &);

		static void implement( std::list< Declaration * > & translationUnit ) {
			MutexKeyword impl;
			acceptAll( translationUnit, impl );
		}

	  private:
	  	StructDecl* monitor_decl = nullptr;
		StructDecl* guard_decl = nullptr;
	};

	//-----------------------------------------------------------------------------
	//Handles thread constructor definitions :
	// void ?{}( MyThread * this ) {                                   void ?{}( MyThread * this ) {
	//    /*Some code*/                                       =>           /*Some code*/
	//                                                                     __thrd_start( this );
	// }                                                               }
	//
	// A constructor (as identified by InitTweak::isConstructor) whose first parameter is
	// a pointer to a thread struct gets a call to __thrd_start, passing that parameter,
	// appended at the end of its body. Prototypes are left untouched.
	class ThreadStarter final : public Visitor {
	  public:

		using Visitor::visit;
		virtual void visit( FunctionDecl * decl ) override final;

		// Appends "__thrd_start( param );" to the body of decl, if it has one.
		void addStartStatement( FunctionDecl * decl, DeclarationWithType * param );

		// Runs the pass over a whole translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			ThreadStarter impl;
			acceptAll( translationUnit, impl );
		}
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	// Runs the thread, coroutine and monitor keyword passes, in that order, over the
	// whole translation unit.
	void applyKeywords( std::list< Declaration * > & translationUnit ) {
		using KeywordPass = void (*)( std::list< Declaration * > & );
		const KeywordPass passes[] = {
			&ThreadKeyword::implement,
			&CoroutineKeyword::implement,
			&MonitorKeyword::implement,
		};
		for( KeywordPass pass : passes ) {
			pass( translationUnit );
		}
	}

	// Expands every routine that has mutex parameters (see MutexKeyword).
	void implementMutexFuncs( std::list< Declaration * > & translationUnit ) {
		MutexKeyword::implement( translationUnit );
	}

	// Adds the __thrd_start call to thread constructors (see ThreadStarter).
	void implementThreadStarter( std::list< Declaration * > & translationUnit ) {
		ThreadStarter::implement( translationUnit );
	}

	//=============================================================================================
	// Generic keyword implementation
	//=============================================================================================
	// Visits the children first, then either records the descriptor struct (type_name)
	// or, for a struct selected by is_target, rewrites it.
	void ConcurrentSueKeyword::visit(StructDecl * decl) {
		Visitor::visit(decl);
		if( decl->get_name() == type_name ) {
			// Only the definition may become type_decl: it is used as the type of the
			// injected field, so it must be complete. Forward declarations of the
			// descriptor are ignored instead of tripping the assert when the definition
			// follows them.
			if( decl->has_body() ) {
				assert( !type_decl );
				type_decl = decl;
			}
		}
		else if ( is_target(decl) ) {
			handle( decl );
		}
	}

	// Rewrites one target struct. Forward declarations are skipped; a definition is
	// rejected with context_error if the descriptor struct has not been seen yet.
	void ConcurrentSueKeyword::handle( StructDecl * decl ) {
		if( decl->has_body() ) {
			if( type_decl == nullptr ) throw SemanticError( context_error, decl );

			FunctionDecl * getter_proto = forwardDeclare( decl );
			ObjectDecl * descriptor = addField( decl );
			addRoutines( decl, descriptor, getter_proto );
		}
	}

	// Queues in declsToAdd, in this order:
	//   struct MyType;                                   (bodyless copy of decl)
	//   main( MyType * this );                           (only when needs_main)
	//   static inline type_name * getter_name( MyType * this );
	// and returns the getter prototype, which addRoutines clones into a definition.
	FunctionDecl * ConcurrentSueKeyword::forwardDeclare( StructDecl * decl ) {
		// Builds "pointer to struct target" types for the parameter and the return value.
		auto pointer_to = []( StructDecl * target ) {
			return new PointerType( noQualifiers, new StructInstType( noQualifiers, target ) );
		};

		// Bodyless copy of the target struct: drop the cloned members.
		StructDecl * forward = decl->clone();
		forward->set_body( false );
		deleteAll( forward->get_members() );
		forward->get_members().clear();

		// Getter signature: ret is a pointer to the descriptor, this a pointer to decl.
		FunctionType * getter_type = new FunctionType( noQualifiers, false );
		ObjectDecl * this_param = new ObjectDecl( "this", noStorage, LinkageSpec::Cforall, nullptr, pointer_to( decl ), nullptr );
		getter_type->get_parameters().push_back( this_param );
		getter_type->get_returnVals().push_back(
			new ObjectDecl( "ret", noStorage, LinkageSpec::Cforall, nullptr, pointer_to( type_decl ), nullptr )
		);

		FunctionDecl * getter = new FunctionDecl(
			getter_name,
			Type::Static,
			LinkageSpec::Cforall,
			getter_type,
			nullptr,
			noAttributes,
			Type::Inline
		);

		declsToAdd.push_back( forward );

		if( needs_main ) {
			// The main prototype takes its own copy of the "this" parameter.
			FunctionType * main_type = new FunctionType( noQualifiers, false );
			main_type->get_parameters().push_back( this_param->clone() );
			declsToAdd.push_back( new FunctionDecl( "main", noStorage, LinkageSpec::Cforall, main_type, nullptr ) );
		}

		declsToAdd.push_back( getter );
		return getter;
	}

	// Appends a "type_name field_name;" member to the target struct and returns it.
	ObjectDecl * ConcurrentSueKeyword::addField( StructDecl * decl ) {
		Type * descriptor_type = new StructInstType( noQualifiers, type_decl );
		ObjectDecl * descriptor = new ObjectDecl( field_name, noStorage, LinkageSpec::Cforall, nullptr, descriptor_type, nullptr );
		decl->get_members().push_back( descriptor );
		return descriptor;
	}

	// Turns a clone of the getter prototype into the definition
	// "{ return &(*this).field; }" and queues it in declsToAddAfter.
	// The decl parameter is currently unused.
	void ConcurrentSueKeyword::addRoutines( StructDecl * decl, ObjectDecl * field, FunctionDecl * func ) {
		CompoundStmt * body = new CompoundStmt( noLabels );

		// NOTE(review): this refers to the prototype's "this" parameter, not to the clone's;
		// presumably fine because the clone's parameter has the same name - confirm.
		DeclarationWithType * this_param = func->get_functionType()->get_parameters().front();
		Expression * deref_this = UntypedExpr::createDeref( new VariableExpr( this_param ) );
		Expression * field_address = new AddressExpr( new MemberExpr( field, deref_this ) );
		body->push_back( new ReturnStmt( noLabels, field_address ) );

		FunctionDecl * definition = func->clone();
		definition->set_statements( body );
		declsToAddAfter.push_back( definition );
	}

	//=============================================================================================
	// Mutex keyword implementation
	//=============================================================================================
	// Validates every mutex parameter of a routine (prototypes included) and, when the
	// routine has a body, prepends the monitor acquisition statements to it.
	void MutexKeyword::visit(FunctionDecl* decl) {
		Visitor::visit(decl);

		std::list<DeclarationWithType*> mutexArgs = findMutexArgs( decl );
		if( mutexArgs.empty() ) return;

		for( DeclarationWithType * arg : mutexArgs ) validate( arg );

		CompoundStmt * body = decl->get_statements();
		if( body == nullptr ) return;

		// Both runtime structs are needed to build __monitors and __guard.
		if( monitor_decl == nullptr || guard_decl == nullptr ) {
			throw SemanticError( "mutex keyword requires monitors to be in scope, add #include <monitor>", decl );
		}

		addStatments( body, mutexArgs );
	}

	// Records the definitions of the monitor_desc and monitor_guard_t runtime structs.
	void MutexKeyword::visit(StructDecl* decl) {
		Visitor::visit(decl);

		// Forward declarations are ignored: __guard needs a complete monitor_guard_t, and a
		// forward declaration followed by the definition would otherwise trip the asserts.
		if( ! decl->has_body() ) return;

		if( decl->get_name() == "monitor_desc" ) {
			assert( !monitor_decl );
			monitor_decl = decl;
		}
		else if( decl->get_name() == "monitor_guard_t" ) {
			assert( !guard_decl );
			guard_decl = decl;
		}
	}

	// Returns, in declaration order, the parameters whose type is mutex-qualified.
	std::list<DeclarationWithType*> MutexKeyword::findMutexArgs( FunctionDecl* decl ) {
		std::list<DeclarationWithType*> found;
		for( DeclarationWithType * param : decl->get_functionType()->get_parameters() ) {
			if( param->get_type()->get_mutex() ) found.push_back( param );
		}
		return found;
	}

	// Throws a SemanticError unless the mutex parameter is a single-level pointer whose
	// pointee is not itself mutex-qualified.
	void MutexKeyword::validate( DeclarationWithType * arg ) {
		// Not passed by copy: the parameter must be a pointer.
		PointerType * pointer = dynamic_cast< PointerType * >( arg->get_type() );
		if( pointer == nullptr ) throw SemanticError( "Mutex argument must be of pointer/reference type ", arg );

		// Exactly one level of indirection: the pointee may not be a pointer.
		Type * pointee = pointer->get_base();
		if( dynamic_cast< PointerType * >( pointee ) != nullptr ) throw SemanticError( "Mutex argument have exactly one level of indirection ", arg );

		// The mutex qualifier may appear only once.
		if( pointee->get_mutex() ) throw SemanticError( "mutex keyword may only appear once per argument ", arg );
	}

	// Prepends to a mutex routine's body, in this order:
	//   monitor_desc * __monitors[N] = { get_monitor( (T)a ), ... };
	//   monitor_guard_t __guard = { __monitors, N };
	// where N is the number of mutex parameters and T is each parameter's type with the
	// mutex qualifier cleared.
	void MutexKeyword::addStatments( CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
		// Initializer "get_monitor( (T)var )" for one mutex parameter.
		auto monitor_of = []( DeclarationWithType * var ) {
			Type * plain_type = var->get_type()->clone();
			plain_type->set_mutex( false );
			return new SingleInit( new UntypedExpr(
				new NameExpr( "get_monitor" ),
				{ new CastExpr( new VariableExpr( var ), plain_type ) }
			) );
		};

		Type * monitor_ptr = new PointerType( noQualifiers, new StructInstType( noQualifiers, monitor_decl ) );
		ObjectDecl * monitors = new ObjectDecl(
			"__monitors",
			noStorage,
			LinkageSpec::Cforall,
			nullptr,
			new ArrayType( noQualifiers, monitor_ptr, new ConstantExpr( Constant::from_ulong( args.size() ) ), false, false ),
			new ListInit( map_range < std::list<Initializer*> > ( args, monitor_of ) )
		);

		ObjectDecl * guard = new ObjectDecl(
			"__guard",
			noStorage,
			LinkageSpec::Cforall,
			nullptr,
			new StructInstType( noQualifiers, guard_decl ),
			new ListInit(
				{
					new SingleInit( new VariableExpr( monitors ) ),
					new SingleInit( new ConstantExpr( Constant::from_ulong( args.size() ) ) )
				},
				noDesignators,
				true
			)
		);

		// push_front reverses insertion order: the guard goes in first so that the
		// array declaration ends up ahead of it.
		body->push_front( new DeclStmt( noLabels, guard ) );
		body->push_front( new DeclStmt( noLabels, monitors ) );
	}

	//=============================================================================================
	// Thread starter implementation
	//=============================================================================================
	// For a constructor whose first parameter is a pointer to a thread struct, appends
	// the __thrd_start call. Every other function is left untouched.
	void ThreadStarter::visit(FunctionDecl * decl) {
		Visitor::visit(decl);

		if( ! InitTweak::isConstructor(decl->get_name()) ) return;

		// Guard against a parameterless declaration: front() on an empty list is UB.
		auto & params = decl->get_functionType()->get_parameters();
		if( params.empty() ) return;

		DeclarationWithType * param = params.front();

		// A first parameter that is not a pointer cannot be a thread's "this";
		// dynamic_cast yields nullptr then, which must not be dereferenced.
		PointerType * ptr = dynamic_cast< PointerType * >( param->get_type() );
		if( ! ptr ) return;

		StructInstType * type = dynamic_cast< StructInstType * >( ptr->get_base() );
		if( type && type->get_baseStruct() && type->get_baseStruct()->is_thread() ) {
			addStartStatement( decl, param );
		}
	}

	// Appends "__thrd_start( param );" as the last statement of decl's body.
	// Declarations without a body are left as they are.
	void ThreadStarter::addStartStatement( FunctionDecl * decl, DeclarationWithType * param ) {
		CompoundStmt * body = decl->get_statements();
		if( body ) {
			Expression * start_call = new UntypedExpr( new NameExpr( "__thrd_start" ), { new VariableExpr( param ) } );
			body->push_back( new ExprStmt( noLabels, start_call ) );
		}
	}
};
