//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Keywords.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Mar 13 12:41:22 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 5
//

#include "Concurrency/Keywords.h"

#include <cassert>                 // for assert
#include <list>                    // for list
#include <memory>                  // for unique_ptr
#include <string>                  // for string, operator==
#include <utility>                 // for move

#include "Common/PassVisitor.h"    // for PassVisitor
#include "Common/SemanticError.h"  // for SemanticError
#include "Common/utility.h"        // for deleteAll, map_range
#include "CodeGen/OperatorTable.h" // for isConstructor
#include "InitTweak/InitTweak.h"   // for getPointerBase
#include "Parser/LinkageSpec.h"    // for Cforall
#include "SynTree/Constant.h"      // for Constant
#include "SynTree/Declaration.h"   // for StructDecl, FunctionDecl, ObjectDecl
#include "SynTree/Expression.h"    // for VariableExpr, ConstantExpr, Untype...
#include "SynTree/Initializer.h"   // for SingleInit, ListInit, Initializer ...
#include "SynTree/Label.h"         // for Label
#include "SynTree/Statement.h"     // for CompoundStmt, DeclStmt, ExprStmt
#include "SynTree/Type.h"          // for StructInstType, Type, PointerType
#include "SynTree/Visitor.h"       // for Visitor, acceptAll

class Attribute;

namespace Concurrency {
	//=============================================================================================
	// Pass declarations
	//=============================================================================================

	//-----------------------------------------------------------------------------
	//Handles struct declarations marked with a concurrency keyword (sue = struct) :
	// sue MyType {                             struct MyType {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             NewField_t newField;
	// };                                        };
	//                                           static inline NewField_t * getter_name( MyType & this ) { return &this.newField; }
	//
	class ConcurrentSueKeyword : public WithDeclsToAdd {
	  public:

		// type_name     : name of the library descriptor struct (e.g. "thread_desc")
		// field_name    : name of the descriptor field added to each keyword struct
		// getter_name   : name of the generated routine returning a pointer to that field
		// context_error : message reported when the descriptor struct is not in scope
		// needs_main    : whether a `main( T & this )` prototype is also generated
		// cast_target   : which KeywordCastExpr kind this pass rewrites
		// The string parameters are rvalue references, so they are moved (not copied) into the members.
		ConcurrentSueKeyword( std::string&& type_name, std::string&& field_name, std::string&& getter_name, std::string&& context_error, bool needs_main, KeywordCastExpr::Target cast_target ) :
		  type_name( std::move( type_name ) ), field_name( std::move( field_name ) ), getter_name( std::move( getter_name ) ), context_error( std::move( context_error ) ), needs_main( needs_main ), cast_target( cast_target ) {}

		virtual ~ConcurrentSueKeyword() {}

		// Records the descriptor struct definition, or expands a struct selected by is_target().
		Declaration * postmutate( StructDecl * decl );

		void handle( StructDecl * );
		FunctionDecl * forwardDeclare( StructDecl * );
		ObjectDecl * addField( StructDecl * );
		void addRoutines( ObjectDecl *, FunctionDecl * );

		// Selects the structs this pass expands (e.g. decl->is_thread() for the thread pass).
		virtual bool is_target( StructDecl * decl ) = 0;

		// Rewrites keyword casts of kind cast_target, e.g. (thread &)t, into descriptor casts.
		Expression * postmutate( KeywordCastExpr * cast );

	  private:
		const std::string type_name;
		const std::string field_name;
		const std::string getter_name;
		const std::string context_error;
		const bool needs_main;
		const KeywordCastExpr::Target cast_target;

		// Definition of the descriptor struct once it has been seen; nullptr before that.
		StructDecl* type_decl = nullptr;
	};


	//-----------------------------------------------------------------------------
	//Handles thread type declarations :
	// thread MyThread {                         struct MyThread {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             thread_desc __thrd;
	// };                                        };
	//                                           static inline thread_desc * get_thread( MyThread & this ) { return &this.__thrd; }
	//
	class ThreadKeyword final : public ConcurrentSueKeyword {
	  public:

	  	ThreadKeyword() : ConcurrentSueKeyword(
			"thread_desc",
			"__thrd",
			"get_thread",
			"thread keyword requires threads to be in scope, add #include <thread>",
			true,
			KeywordCastExpr::Thread
		)
		{}

		virtual ~ThreadKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_thread(); }

		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< ThreadKeyword > impl;
			mutateAll( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles coroutine type declarations :
	// coroutine MyCoroutine {                   struct MyCoroutine {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             coroutine_desc __cor;
	// };                                        };
	//                                           static inline coroutine_desc * get_coroutine( MyCoroutine & this ) { return &this.__cor; }
	//
	class CoroutineKeyword final : public ConcurrentSueKeyword {
	  public:

	  	CoroutineKeyword() : ConcurrentSueKeyword(
			"coroutine_desc",
			"__cor",
			"get_coroutine",
			"coroutine keyword requires coroutines to be in scope, add #include <coroutine>",
			true,
			KeywordCastExpr::Coroutine
		)
		{}

		virtual ~CoroutineKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_coroutine(); }

		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< CoroutineKeyword > impl;
			mutateAll( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles monitor type declarations :
	// monitor MyMonitor {                       struct MyMonitor {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             monitor_desc __mon;
	// };                                        };
	//                                           static inline monitor_desc * get_monitor( MyMonitor & this ) { return &this.__mon; }
	//
	class MonitorKeyword final : public ConcurrentSueKeyword {
	  public:

	  	MonitorKeyword() : ConcurrentSueKeyword(
			"monitor_desc",
			"__mon",
			"get_monitor",
			"monitor keyword requires monitors to be in scope, add #include <monitor>",
			false,
			KeywordCastExpr::Monitor
		)
		{}

		virtual ~MonitorKeyword() {}

		virtual bool is_target( StructDecl * decl ) override final { return decl->is_monitor(); }

		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< MonitorKeyword > impl;
			mutateAll( translationUnit, impl );
		}
	};

	//-----------------------------------------------------------------------------
	//Handles mutex routines definitions :
	// void foo( A & mutex a, B & mutex b,  int i ) {                  void foo( A & a, B & b,  int i ) {
	// 	                                                                 monitor_desc * __monitors[] = { get_monitor(a), get_monitor(b) };
	// 	                                                                 monitor_guard_t __guard = { __monitors, 2, foo };
	//    /*Some code*/                                       =>           /*Some code*/
	// }                                                               }
	// (foo is cast to a generic function type; destructors use a single monitor and monitor_dtor_guard_t)
	//
	class MutexKeyword final {
	  public:
		// Pass hooks: instrument mutex routines / record the monitor runtime structs.
		void postvisit( FunctionDecl * decl );
		void postvisit( StructDecl * decl );

		// Helpers used while handling a FunctionDecl.
		std::list<DeclarationWithType*> findMutexArgs( FunctionDecl*, bool & first );
		void validate( DeclarationWithType * );
		void addDtorStatments( FunctionDecl* func, CompoundStmt *, const std::list<DeclarationWithType * > &);
		void addStatments( FunctionDecl* func, CompoundStmt *, const std::list<DeclarationWithType * > &);

		// Runs the mutex-routine pass over the whole translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< MutexKeyword > visitor;
			acceptAll( translationUnit, visitor );
		}

	  private:
		// Runtime struct declarations recorded by postvisit( StructDecl * ); nullptr until seen.
		StructDecl * monitor_decl    = nullptr;  // "monitor_desc"
		StructDecl * guard_decl      = nullptr;  // "monitor_guard_t"
		StructDecl * dtor_guard_decl = nullptr;  // "monitor_dtor_guard_t"

		// Function type each instrumented routine is cast to when stored in its guard.
		static std::unique_ptr< Type > generic_func;
	};

	// Shared function type used as the cast target for routines stored in monitor guards.
	// NOTE(review): the `true` argument presumably marks the type as variadic — confirm against FunctionType.
	std::unique_ptr< Type > MutexKeyword::generic_func( new FunctionType( noQualifiers, true ) );

	//-----------------------------------------------------------------------------
	//Handles thread constructors :
	// void ?{}( MyThread & this, ... ) {                              void ?{}( MyThread & this, ... ) {
	//    /*Some code*/                                       =>           /*Some code*/
	// }                                                                   __thrd_start( this );
	//                                                                 }
	// An error is reported if the thread_desc definition, or a thread_desc constructor, has not been seen yet.
	//
	class ThreadStarter final {
	  public:

		void postvisit( FunctionDecl * decl );
		void previsit ( StructDecl   * decl );

		void addStartStatement( FunctionDecl * decl, DeclarationWithType * param );

		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< ThreadStarter > impl;
			acceptAll( translationUnit, impl );
		}

	  private :
		bool thread_ctor_seen = false;
		StructDecl * thread_decl = nullptr;
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	// Entry point: expands the `thread`, `coroutine` and `monitor` struct keywords,
	// one full pass per keyword, in that order.
	void applyKeywords( std::list< Declaration * > & translationUnit ) {
		ThreadKeyword::implement( translationUnit );
		CoroutineKeyword::implement( translationUnit );
		MonitorKeyword::implement( translationUnit );
	}

	// Entry point: validates mutex parameters and inserts monitor guards into mutex routines.
	void implementMutexFuncs( std::list< Declaration * > & translationUnit ) {
		MutexKeyword::implement( translationUnit );
	}

	// Entry point: appends a __thrd_start call to the constructors of thread types.
	void implementThreadStarter( std::list< Declaration * > & translationUnit ) {
		ThreadStarter::implement( translationUnit );
	}

	//=============================================================================================
	// Generic keyword implementation
	//=============================================================================================
	// Makes a generated routine generic over the type parameters of struct `decl`:
	// decl's type parameters are cloned into func's forall list, and func's first parameter
	// type (after stripping references; the cast is strict, so it must be a StructInstType)
	// is instantiated with one TypeExpr per forall entry.
	void fixupGenerics(FunctionType * func, StructDecl * decl) {
		cloneAll(decl->parameters, func->forall);
		// Nothing to instantiate; also avoids touching parameters.front() needlessly.
		if ( func->forall.empty() ) return;

		// Loop-invariant: the first parameter's struct instance is the same for every type variable.
		StructInstType * paramType = strict_dynamic_cast<StructInstType*>(
			func->parameters.front()->get_type()->stripReferences()
		);
		for ( TypeDecl * td : func->forall ) {
			paramType->parameters.push_back(
				new TypeExpr( new TypeInstType( noQualifiers, td->name, td ) )
			);
		}
	}

	// Records the definition of the descriptor struct (type_name) the first time it is seen;
	// any other struct selected by is_target() is expanded by handle(). Always returns decl.
	Declaration * ConcurrentSueKeyword::postmutate(StructDecl * decl) {
		const bool isDescriptorDefinition = decl->name == type_name && decl->body;
		if( isDescriptorDefinition ) {
			assert( !type_decl );
			type_decl = decl;
			return decl;
		}

		if( is_target( decl ) ) handle( decl );
		return decl;
	}

	// Rewrites a keyword cast aimed at this pass's keyword into a cast of the accessor result:
	// the keyword cast node is freed and its operand reused. Other casts are returned unchanged.
	Expression * ConcurrentSueKeyword::postmutate( KeywordCastExpr * cast ) {
		if ( cast->target != cast_target ) return cast;

		// convert (thread &)t to (thread_desc &)*get_thread(t), etc.
		if( !type_decl ) SemanticError( cast, context_error );

		// Detach the operand first so deleting the cast node does not free it.
		Expression * operand = cast->arg;
		cast->arg = nullptr;
		delete cast;

		Expression * getterCall = new UntypedExpr( new NameExpr( getter_name ), { operand } );
		Type * descriptorRef = new ReferenceType( noQualifiers, new StructInstType( noQualifiers, type_decl ) );
		return new CastExpr( UntypedExpr::createDeref( getterCall ), descriptorRef );
	}


	// Expands one keyword struct definition: emits forward declarations and the getter
	// prototype, embeds the descriptor field, then emits the getter definition.
	// Structs without a body are skipped.
	void ConcurrentSueKeyword::handle( StructDecl * decl ) {
		if( decl->body ) {
			if( !type_decl ) SemanticError( decl, context_error );

			// Order matters: the forward declaration is built before the field is added.
			FunctionDecl * getterProto = forwardDeclare( decl );
			ObjectDecl * descriptorField = addField( decl );
			addRoutines( descriptorField, getterProto );
		}
	}

	// Queues, via declsToAddBefore, the declarations that must precede the keyword struct `decl`:
	//   - a forward declaration of `decl` (a clone with no body and no members),
	//   - if needs_main, a prototype `main( decl & this )`,
	//   - a prototype `static inline type_name * getter_name( decl & this )`.
	// Both routine types are made generic over decl's type parameters by fixupGenerics.
	// Returns the getter prototype (no statements); handle() passes it to addRoutines,
	// which emits the definition.
	FunctionDecl * ConcurrentSueKeyword::forwardDeclare( StructDecl * decl ) {

		// Forward declaration: clone decl, then clear the body flag and free/drop all members.
		StructDecl * forward = decl->clone();
		forward->set_body( false );
		deleteAll( forward->get_members() );
		forward->get_members().clear();

		FunctionType * get_type = new FunctionType( noQualifiers, false );
		// Template for the `this` parameter (decl & this); cloned into each routine type below
		// and deleted once both clones have been made.
		ObjectDecl * this_decl = new ObjectDecl(
			"this",
			noStorageClasses,
			LinkageSpec::Cforall,
			nullptr,
			new ReferenceType(
				noQualifiers,
				new StructInstType(
					noQualifiers,
					decl
				)
			),
			nullptr
		);

		// Getter signature: ( decl & this ) -> pointer to the descriptor struct type_decl.
		get_type->get_parameters().push_back( this_decl->clone() );
		get_type->get_returnVals().push_back(
			new ObjectDecl(
				"ret",
				noStorageClasses,
				LinkageSpec::Cforall,
				nullptr,
				new PointerType(
					noQualifiers,
					new StructInstType(
						noQualifiers,
						type_decl
					)
				),
				nullptr
			)
		);
		fixupGenerics(get_type, decl);

		// static inline prototype with no statements.
		FunctionDecl * get_decl = new FunctionDecl(
			getter_name,
			Type::Static,
			LinkageSpec::Cforall,
			get_type,
			nullptr,
			noAttributes,
			Type::Inline
		);

		FunctionDecl * main_decl = nullptr;

		if( needs_main ) {
			// Prototype `main( decl & this )` with no statements.
			FunctionType * main_type = new FunctionType( noQualifiers, false );

			main_type->get_parameters().push_back( this_decl->clone() );

			main_decl = new FunctionDecl(
				"main",
				noStorageClasses,
				LinkageSpec::Cforall,
				main_type,
				nullptr
			);
			fixupGenerics(main_type, decl);
		}

		delete this_decl;

		// Emitted in this order ahead of the struct: forward decl, main (optional), getter.
		declsToAddBefore.push_back( forward );
		if( needs_main ) declsToAddBefore.push_back( main_decl );
		declsToAddBefore.push_back( get_decl );

		return get_decl;
	}

	// Appends a member `type_name field_name;` to the keyword struct and returns it.
	ObjectDecl * ConcurrentSueKeyword::addField( StructDecl * decl ) {
		StructInstType * descriptorType = new StructInstType( noQualifiers, type_decl );
		ObjectDecl * descriptorField = new ObjectDecl(
			field_name, noStorageClasses, LinkageSpec::Cforall, nullptr, descriptorType, nullptr
		);

		decl->get_members().push_back( descriptorField );
		return descriptorField;
	}

	// Emits, via declsToAddAfter, the getter definition: a clone of the prototype `func`
	// whose body is `return &((StructType)this).<field>;`.
	void ConcurrentSueKeyword::addRoutines( ObjectDecl * field, FunctionDecl * func ) {
		DeclarationWithType * thisParam = func->get_functionType()->get_parameters().front();

		// Cast `this` to its type with references stripped before the member access.
		Expression * thisValue = new CastExpr(
			new VariableExpr( thisParam ),
			thisParam->get_type()->stripReferences()->clone()
		);

		CompoundStmt * body = new CompoundStmt();
		body->push_back( new ReturnStmt( new AddressExpr( new MemberExpr( field, thisValue ) ) ) );

		FunctionDecl * definition = func->clone();
		definition->set_statements( body );
		declsToAddAfter.push_back( definition );

		// definition->fixUniqueId();  (left disabled, as in the original)
	}

	//=============================================================================================
	// Mutex keyword implementation
	//=============================================================================================

	// Validates a routine's mutex parameters and, when it has a body, prepends the monitor
	// guard declarations. A destructor without mutex parameters is rejected when its `this`
	// refers to a struct declared as monitor or thread.
	void MutexKeyword::postvisit(FunctionDecl* decl) {

		bool first = false;
		std::list<DeclarationWithType*> mutexArgs = findMutexArgs( decl, first );
		bool isDtor = CodeGen::isDestructor( decl->name );

		// Is this function relevant to monitors
		if( mutexArgs.empty() ) {
			// If this is the destructor for a monitor it must be mutex
			if(isDtor) {
				auto & params = decl->get_functionType()->get_parameters();
				// Without a parameter there is nothing to inspect (and front() would be undefined).
				if( params.empty() ) return;
				Type* ty = params.front()->get_type();

				// If it's a copy, it's not a mutex
				ReferenceType* rty = dynamic_cast< ReferenceType * >( ty );
				if( ! rty ) return;

				// If we are not pointing directly to a type, it's not a mutex
				Type* base = rty->get_base();
				if( dynamic_cast< ReferenceType * >( base ) ) return;
				if( dynamic_cast< PointerType * >( base ) ) return;

				// Check if its a struct whose declaration is linked (baseStruct may be nullptr)
				StructInstType * baseStruct = dynamic_cast< StructInstType * >( base );
				if( !baseStruct || !baseStruct->baseStruct ) return;

				// Check if its a monitor or a thread
				if(baseStruct->baseStruct->is_monitor() || baseStruct->baseStruct->is_thread())
					SemanticError( decl, "destructors for structures declared as \"monitor\" must use mutex parameters\n" );
			}
			return;
		}

		// Monitors can't be constructed with mutual exclusion.
		// NOTE(review): findMutexArgs sets `first` to true as soon as it finds any mutex parameter,
		// so at this point (mutexArgs non-empty) `first` is always true and this check never fires.
		// The intended rule is unclear; confirm before changing it.
		if( CodeGen::isConstructor(decl->name) && !first ) SemanticError( decl, "constructors cannot have mutex parameters" );

		// It makes no sense to have multiple mutex parameters for the destructor
		if( isDtor && mutexArgs.size() != 1 ) SemanticError( decl, "destructors can only have 1 mutex argument" );

		// Make sure all the mutex arguments are monitors
		for(auto arg : mutexArgs) {
			validate( arg );
		}

		// Check if we need to instrument the body
		CompoundStmt* body = decl->get_statements();
		if( ! body ) return;

		// Do we have the required headers
		if( !monitor_decl || !guard_decl || !dtor_guard_decl )
			SemanticError( decl, "mutex keyword requires monitors to be in scope, add #include <monitor>\n" );

		// Instrument the body
		if( isDtor ) {
			addDtorStatments( decl, body, mutexArgs );
		}
		else {
			addStatments( decl, body, mutexArgs );
		}
	}

	// Records the definitions of the monitor runtime structs needed to instrument mutex routines.
	// Declarations without a body are ignored, matching ConcurrentSueKeyword::postmutate and
	// ThreadStarter::previsit. This makes the recorded pointers refer to the definition, and a
	// forward declaration followed by the definition no longer trips the uniqueness asserts.
	void MutexKeyword::postvisit(StructDecl* decl) {
		if( ! decl->body ) return;

		if( decl->name == "monitor_desc" ) {
			assert( !monitor_decl );
			monitor_decl = decl;
		}
		else if( decl->name == "monitor_guard_t" ) {
			assert( !guard_decl );
			guard_decl = decl;
		}
		else if( decl->name == "monitor_dtor_guard_t" ) {
			assert( !dtor_guard_decl );
			dtor_guard_decl = decl;
		}
	}

	// Returns the parameters of `decl` whose type carries the mutex qualifier, in declaration order.
	// `first` is set to true if at least one such parameter exists (regardless of its position)
	// and is left untouched otherwise.
	std::list<DeclarationWithType*> MutexKeyword::findMutexArgs( FunctionDecl* decl, bool & first ) {
		std::list<DeclarationWithType*> found;
		for( auto param : decl->get_functionType()->get_parameters() ) {
			if( param->get_type()->get_mutex() ) found.push_back( param );
		}

		if( ! found.empty() ) first = true;
		return found;
	}

	// Checks that a mutex parameter is a single reference directly to an object whose
	// referenced type is not itself mutex-qualified; reports a SemanticError otherwise.
	void MutexKeyword::validate( DeclarationWithType * arg ) {
		// Must not be passed by copy.
		ReferenceType * ref = dynamic_cast< ReferenceType * >( arg->get_type() );
		if( ! ref ) SemanticError( arg, "Mutex argument must be of reference type " );

		// Must refer directly to the object: no reference-to-reference or reference-to-pointer.
		Type * pointee = ref->get_base();
		const bool extraIndirection =
			dynamic_cast< ReferenceType * >( pointee ) || dynamic_cast< PointerType * >( pointee );
		if( extraIndirection ) SemanticError( arg, "Mutex argument have exactly one level of indirection " );

		// `mutex` may be written on the parameter, not also on the referenced type.
		if( pointee->get_mutex() ) SemanticError( arg, "mutex keyword may only appear once per argument " );
	}

	// Instruments the body of a mutex destructor (postvisit only calls this with exactly one
	// mutex argument, after checking monitor_decl and dtor_guard_decl are known). Prepends:
	//   monitor_desc * __monitor = get_monitor( (T &)a );     // `a` cast to its type without mutex
	//   monitor_dtor_guard_t __guard = { &__monitor, func };  // func cast to generic_func's type
	void MutexKeyword::addDtorStatments( FunctionDecl* func, CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
		// Copy of the argument's type with the mutex qualifier removed, used as the cast target.
		Type * arg_type = args.front()->get_type()->clone();
		arg_type->set_mutex( false );

		ObjectDecl * monitors = new ObjectDecl(
			"__monitor",
			noStorageClasses,
			LinkageSpec::Cforall,
			nullptr,
			new PointerType(
				noQualifiers,
				new StructInstType(
					noQualifiers,
					monitor_decl
				)
			),
			new SingleInit( new UntypedExpr(
				new NameExpr( "get_monitor" ),
				{  new CastExpr( new VariableExpr( args.front() ), arg_type ) }
			))
		);

		assert(generic_func);

		// Statements are pushed to the front, so they are added in reverse order :
		// monitor_dtor_guard_t __guard = { &__monitor, func };
		body->push_front(
			new DeclStmt( new ObjectDecl(
				"__guard",
				noStorageClasses,
				LinkageSpec::Cforall,
				nullptr,
				new StructInstType(
					noQualifiers,
					dtor_guard_decl
				),
				new ListInit(
					{
						new SingleInit( new AddressExpr( new VariableExpr( monitors ) ) ),
						new SingleInit( new CastExpr( new VariableExpr( func ), generic_func->clone() ) )
					},
					noDesignators,
					true
				)
			))
		);

		// monitor_desc * __monitor = get_monitor( a );
		body->push_front( new DeclStmt( monitors) );
	}

	// Instruments the body of a (non-destructor) mutex routine. Prepends:
	//   monitor_desc * __monitors[N] = { get_monitor( (A &)a ), ... };  // one per mutex argument
	//   monitor_guard_t __guard = { __monitors, N, func };              // func cast to generic_func's type
	// where N = args.size(). postvisit checks monitor_decl and guard_decl are known before calling.
	void MutexKeyword::addStatments( FunctionDecl* func, CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
		ObjectDecl * monitors = new ObjectDecl(
			"__monitors",
			noStorageClasses,
			LinkageSpec::Cforall,
			nullptr,
			new ArrayType(
				noQualifiers,
				new PointerType(
					noQualifiers,
					new StructInstType(
						noQualifiers,
						monitor_decl
					)
				),
				new ConstantExpr( Constant::from_ulong( args.size() ) ),
				false,
				false
			),
			new ListInit(
				// One get_monitor( (T)var ) initializer per argument, T being var's type without mutex.
				map_range < std::list<Initializer*> > ( args, [](DeclarationWithType * var ){
					Type * type = var->get_type()->clone();
					type->set_mutex( false );
					return new SingleInit( new UntypedExpr(
						new NameExpr( "get_monitor" ),
						{  new CastExpr( new VariableExpr( var ), type ) }
					) );
				})
			)
		);

		assert(generic_func);

		// Statements are pushed to the front, so they are added in reverse order :
		// monitor_guard_t __guard = { __monitors, #, func };
		body->push_front(
			new DeclStmt( new ObjectDecl(
				"__guard",
				noStorageClasses,
				LinkageSpec::Cforall,
				nullptr,
				new StructInstType(
					noQualifiers,
					guard_decl
				),
				new ListInit(
					{
						new SingleInit( new VariableExpr( monitors ) ),
						new SingleInit( new ConstantExpr( Constant::from_ulong( args.size() ) ) ),
						new SingleInit( new CastExpr( new VariableExpr( func ), generic_func->clone() ) )
					},
					noDesignators,
					true
				)
			))
		);

		//monitor_desc * __monitors[] = { get_monitor(a), get_monitor(b) };
		body->push_front( new DeclStmt( monitors) );
	}

	//=============================================================================================
	// Thread starter implementation
	//=============================================================================================
	// Records the definition of thread_desc; declarations without a body are ignored.
	void ThreadStarter::previsit( StructDecl * decl ) {
		const bool isThreadDescDefinition = decl->body && decl->name == "thread_desc";
		if( ! isThreadDescDefinition ) return;

		assert( !thread_decl );
		thread_decl = decl;
	}

	// For constructors: notes constructors of thread_desc itself, and for constructors of a
	// thread type appends `__thrd_start( this )` via addStartStatement, after checking that the
	// thread_desc definition and one of its constructors have already been seen.
	void ThreadStarter::postvisit(FunctionDecl * decl) {
		if( ! CodeGen::isConstructor(decl->name) ) return;

		// Guarding on thread_decl prevents a spurious match when both thread_decl and the
		// constructed type's baseStruct are still nullptr.
		Type * typeof_this = InitTweak::getTypeofThis(decl->type);
		StructInstType * ctored_type = dynamic_cast< StructInstType * >( typeof_this );
		if( thread_decl && ctored_type && ctored_type->baseStruct == thread_decl ) {
			thread_ctor_seen = true;
		}

		DeclarationWithType * param = decl->get_functionType()->get_parameters().front();
		auto type  = dynamic_cast< StructInstType * >( InitTweak::getPointerBase( param->get_type() ) );
		// Skip instances with no linked declaration instead of dereferencing nullptr.
		StructDecl * base = type ? type->get_baseStruct() : nullptr;
		if( base && base->is_thread() ) {
			if( !thread_decl || !thread_ctor_seen ) {
				SemanticError( base->location, "thread keyword requires threads to be in scope, add #include <thread>");
			}

			addStartStatement( decl, param );
		}
	}

	// Appends `__thrd_start( param );` at the end of the constructor body.
	// Constructors without a body are left untouched.
	void ThreadStarter::addStartStatement( FunctionDecl * decl, DeclarationWithType * param ) {
		CompoundStmt * body = decl->get_statements();
		if( body == nullptr ) return;

		Expression * startCall = new UntypedExpr( new NameExpr( "__thrd_start" ), { new VariableExpr( param ) } );
		body->push_back( new ExprStmt( startCall ) );
	}
};

// Local Variables: //
// mode: c //
// tab-width: 4 //
// End: //
