//                              -*- Mode: CPP -*-
//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Keywords.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Mar 13 12:41:22 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 3
//

#include "Concurrency/Keywords.h"

#include "SymTab/AddVisit.h"
#include "SynTree/Declaration.h"
#include "SynTree/Expression.h"
#include "SynTree/Initializer.h"
#include "SynTree/Mutator.h"
#include "SynTree/Statement.h"
#include "SynTree/Type.h"
#include "SynTree/Visitor.h"

namespace Concurrency {

	namespace {
		const std::list<Label> noLabels;
		const std::list< Attribute * > noAttributes;
		Type::StorageClasses noStorage;
		Type::Qualifiers noQualifiers;
	}

	//=============================================================================================
	// Visitors declaration
	//=============================================================================================

	//-----------------------------------------------------------------------------
	//Handles thread type declarations :
	// thread MyThread {                         struct MyThread {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             thread_desc __thrd_d;
	// };                                        };
	//                                           static inline thread_desc * get_thread( MyThread * this ) { return &this->__thrd_d; }
	//                                           void main( MyThread * this );
	//
	class ThreadKeyword final : public Mutator {
	  public:

		static void implement( std::list< Declaration * > & translationUnit ) {}
	};

	//-----------------------------------------------------------------------------
	//Handles coroutine type declarations :
	// coroutine MyCoroutine {                   struct MyCoroutine {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             coroutine_desc __cor;
	// };                                        };
	//                                           static inline coroutine_desc * get_coroutine( MyCoroutine * this ) { return &this->__cor; }
	//                                           void main( MyCoroutine * this );
	//
	class CoroutineKeyword final : public Visitor {
		// SymTab::acceptAndAdd needs access to the private declsToAdd / declsToAddAfter queues.
		// The template parameter is named VisitorT (not Visitor) so it cannot be confused with
		// the Visitor base class referred to by `using Visitor::visit` below.
		template< typename VisitorT >
		friend void SymTab::acceptAndAdd( std::list< Declaration * > &translationUnit, VisitorT &visitor );
	  public:

		using Visitor::visit;
		// Records the coroutine_desc definition and expands every coroutine struct.
		virtual void visit( StructDecl * decl ) override final;

		// Expands one coroutine struct (adds the descriptor field and its accessor).
		void handle( StructDecl * );
		// Appends the coroutine_desc member to the struct and returns it.
		Declaration * addField( StructDecl * );
		// Generates get_coroutine() for the struct, reading the given field.
		void addRoutines( StructDecl *, Declaration * );

		// Entry point: runs the pass through SymTab::acceptAndAdd, which also receives the
		// declarations queued by this visitor.
		static void implement( std::list< Declaration * > & translationUnit ) {
			CoroutineKeyword impl;
			SymTab::acceptAndAdd( translationUnit, impl );
		}

	  private:
		// Declarations generated while visiting; read by SymTab::acceptAndAdd (see friend above).
		std::list< Declaration * > declsToAdd, declsToAddAfter;
		// Definition of coroutine_desc once seen; null until then.
		StructDecl* coroutine_decl = nullptr;
	};

	//-----------------------------------------------------------------------------
	//Handles monitor type declarations :
	// monitor MyMonitor {                       struct MyMonitor {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             monitor_desc __mon_d;
	// };                                        };
	//                                           static inline monitor_desc * get_monitor( MyMonitor * this ) { return &this->__mon_d; }
	//                                           void main( MyMonitor * this );
	//
	class MonitorKeyword final : public Mutator {
	  public:

		static void implement( std::list< Declaration * > & translationUnit ) {}
	};

	//-----------------------------------------------------------------------------
	//Handles mutex routines definitions :
	// void foo( A * mutex a, B * mutex b,  int i ) {                  void foo( A * a, B * b,  int i ) {
	// 	                                                                 monitor_desc * __monitors[] = { get_monitor(a), get_monitor(b) };
	// 	                                                                 monitor_guard_t __guard = { __monitors, 2 };
	//    /*Some code*/                                       =>           /*Some code*/
	// }                                                               }
	//
	class MutexKeyword final : public Visitor {
	  public:

		using Visitor::visit;
		virtual void visit( FunctionDecl * decl ) override final;
		virtual void visit(   StructDecl * decl ) override final;

		std::list<DeclarationWithType*> findMutexArgs( FunctionDecl* );
		void validate( DeclarationWithType * );
		void addStatments( CompoundStmt *, const std::list<DeclarationWithType * > &);

		static void implement( std::list< Declaration * > & translationUnit ) {
			MutexKeyword impl;
			acceptAll( translationUnit, impl );
		}

	  private:
	  	StructDecl* monitor_decl = nullptr;
		StructDecl* guard_decl = nullptr;
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	// Runs the concurrency keyword passes over the translation unit, one after another:
	// thread, coroutine, monitor, then mutex-parameter lowering.
	void applyKeywords( std::list< Declaration * > & translationUnit ) {
		ThreadKeyword::implement( translationUnit );
		CoroutineKeyword::implement( translationUnit );
		MonitorKeyword::implement( translationUnit );
		MutexKeyword::implement( translationUnit );
	}

	//=============================================================================================
	// Coroutine keyword implementation
	//=============================================================================================
	// Visits struct declarations:
	//  - the definition of coroutine_desc is remembered so generated code can refer to it;
	//  - structs declared with the coroutine keyword are expanded by handle().
	void CoroutineKeyword::visit(StructDecl * decl) {
		if( decl->get_name() == "coroutine_desc" ) {
			// Only the definition is recorded. A forward declaration (`struct coroutine_desc;`)
			// would otherwise trip the assertion once the definition is seen. Recording it alone
			// would also make addField emit a by-value member of incomplete type.
			if( decl->has_body() ) {
				assert( !coroutine_decl );
				coroutine_decl = decl;
			}
		}
		else if ( decl->is_coroutine() ) {
			handle( decl );
		}
	}

	// Expands a single coroutine struct: forward declarations are skipped; definitions get the
	// coroutine_desc member (addField) and then the get_coroutine accessor (addRoutines).
	// Throws SemanticError when coroutine_desc has not been seen yet.
	void CoroutineKeyword::handle( StructDecl * decl ) {
		if( decl->has_body() ) {
			if( coroutine_decl == nullptr ) {
				throw SemanticError( "coroutine keyword requires coroutines to be in scope, add #include <coroutine>", decl );
			}
			addRoutines( decl, addField( decl ) );
		}
	}

	// Appends a member `coroutine_desc __cor;` to the struct's member list and returns it.
	Declaration * CoroutineKeyword::addField( StructDecl * decl ) {
		StructInstType * descriptorType = new StructInstType( noQualifiers, coroutine_decl );
		ObjectDecl * member = new ObjectDecl( "__cor", noStorage, LinkageSpec::Cforall, nullptr, descriptorType, nullptr );

		decl->get_members().push_back( member );
		return member;
	}

	// Generates the accessor for the coroutine descriptor embedded in `decl`:
	//   static inline coroutine_desc * get_coroutine( T * this ) { return &(*this).<field>; }
	// where <field> is the member created by addField. The routine is queued in declsToAddAfter,
	// which SymTab::acceptAndAdd presumably emits after the struct.
	void CoroutineKeyword::addRoutines( StructDecl * decl, Declaration * field ) {
		assert( field );

		// Signature: ( T * this ) -> coroutine_desc *
		FunctionType * type = new FunctionType( noQualifiers, false );
		type->get_parameters().push_back(
			new ObjectDecl(
				"this",
				noStorage,
				LinkageSpec::Cforall,
				nullptr,
				new PointerType(
					noQualifiers,
					new StructInstType(
						noQualifiers,
						decl
					)
				),
				nullptr
			)
		);
		type->get_returnVals().push_back(
			new ObjectDecl(
				"ret",
				noStorage,
				LinkageSpec::Cforall,
				nullptr,
				new PointerType(
					noQualifiers,
					new StructInstType(
						noQualifiers,
						coroutine_decl
					)
				),
				nullptr
			)
		);

		// Body: return &(*this).<field>;
		// The member is referenced by the name addField actually gave it instead of repeating
		// the literal, so the accessor cannot drift out of sync with the field.
		CompoundStmt * statement = new CompoundStmt( noLabels );
		statement->push_back(
			new ReturnStmt(
				noLabels,
				new AddressExpr(
					new UntypedMemberExpr(
						new NameExpr( field->get_name() ),
						new UntypedExpr(
							new NameExpr( "*?" ),
							{ new NameExpr( "this" ) }
						)
					)
				)
			)
		);

		FunctionDecl * get_decl = new FunctionDecl(
			"get_coroutine",
			Type::Static,
			LinkageSpec::Cforall,
			type,
			statement,
			noAttributes,
			Type::Inline
		);

		declsToAddAfter.push_back( get_decl );

		get_decl->fixUniqueId();
	}
	

	//=============================================================================================
	// Mutex keyword implementation
	//=============================================================================================
	// Lowers a routine with mutex parameters.
	// Every mutex parameter is validated, including on prototypes. Only definitions (routines
	// with a body) get the monitor array and guard prepended. Throws SemanticError when
	// monitor_desc or monitor_guard_t has not been seen yet.
	void MutexKeyword::visit(FunctionDecl* decl) {
		const std::list<DeclarationWithType*> mutexArgs = findMutexArgs( decl );
		if( mutexArgs.empty() ) return;

		for( DeclarationWithType * param : mutexArgs ) validate( param );

		CompoundStmt * body = decl->get_statements();
		if( body == nullptr ) return;

		// Both library types are required; the diagnostic is the same for either one missing.
		if( monitor_decl == nullptr || guard_decl == nullptr ) {
			throw SemanticError( "mutex keyword requires monitors to be in scope, add #include <monitor>", decl );
		}

		addStatments( body, mutexArgs );
	}

	// Records the definitions of the library types the mutex lowering refers to:
	// monitor_desc (element type of the generated monitor array) and monitor_guard_t
	// (type of the generated guard object).
	void MutexKeyword::visit(StructDecl* decl) {
		// Forward declarations are ignored. Recording one would trip the assertions below once the
		// definition is seen, and monitor_guard_t is instantiated by value, so it must be complete.
		if( ! decl->has_body() ) return;

		if( decl->get_name() == "monitor_desc" ) {
			assert( !monitor_decl );
			monitor_decl = decl;
		}
		else if( decl->get_name() == "monitor_guard_t" ) {
			assert( !guard_decl );
			guard_decl = decl;
		}
	}

	// Returns, in declaration order, the parameters of `decl` whose type carries the mutex qualifier.
	std::list<DeclarationWithType*> MutexKeyword::findMutexArgs( FunctionDecl* decl ) {
		std::list<DeclarationWithType*> found;
		for( auto param : decl->get_functionType()->get_parameters() ) {
			if( param->get_type()->get_mutex() ) {
				found.push_back( param );
			}
		}
		return found;
	}

	// Checks that a mutex parameter is a single pointer to a type that is not itself mutex.
	// Throws SemanticError naming `arg` otherwise.
	void MutexKeyword::validate( DeclarationWithType * arg ) {
		// The parameter must be a pointer; a mutex value parameter would be a copy.
		PointerType * pointer = dynamic_cast< PointerType * >( arg->get_type() );
		if( pointer == nullptr ) {
			throw SemanticError( "Mutex argument must be of pointer/reference type ", arg );
		}

		// Exactly one level of indirection: the pointee may not be another pointer.
		Type * pointee = pointer->get_base();
		if( dynamic_cast< PointerType * >( pointee ) != nullptr ) {
			throw SemanticError( "Mutex argument have exactly one level of indirection ", arg );
		}

		// The mutex qualifier may appear only once, on the outer pointer.
		if( pointee->get_mutex() ) {
			throw SemanticError( "mutex keyword may only appear once per argument ", arg );
		}
	}

	// Prepends the monitor acquisition to a mutex routine's body, so it starts with:
	//   monitor_desc * __monitors[N] = { get_monitor( (T1 *)a ), get_monitor( (T2 *)b ), ... };
	//   monitor_guard_t __guard = { __monitors, N };
	// where N is the number of mutex parameters in `args`. Each argument is cast to its own
	// type with the mutex qualifier removed.
	void MutexKeyword::addStatments( CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
		const unsigned long count = args.size();

		ObjectDecl * monitors = new ObjectDecl(
			"__monitors",
			noStorage,
			LinkageSpec::Cforall,
			nullptr,
			new ArrayType(
				noQualifiers,
				new PointerType(
					noQualifiers,
					new StructInstType(
						noQualifiers,
						monitor_decl
					)
				),
				new ConstantExpr( Constant::from_ulong( count ) ),
				false,
				false
			),
			new ListInit(
				// The lambda needs no enclosing state, so it captures nothing.
				map_range < std::list<Initializer*> > ( args, [](DeclarationWithType * var ){
					Type * type = var->get_type()->clone();
					type->set_mutex( false );
					return new SingleInit( new UntypedExpr(
						new NameExpr( "get_monitor" ),
						{  new CastExpr( new VariableExpr( var ), type ) }
					) );
				})
			)
		);

		// Both statements go to the front of the body, so they are pushed in reverse order.
		// monitor_guard_t __guard = { __monitors, # };
		body->push_front(
			new DeclStmt( noLabels, new ObjectDecl(
				"__guard",
				noStorage,
				LinkageSpec::Cforall,
				nullptr,
				new StructInstType(
					noQualifiers,
					guard_decl
				),
				new ListInit(
					{
						new SingleInit( new VariableExpr( monitors ) ),
						new SingleInit( new ConstantExpr( Constant::from_ulong( count ) ) )
					},
					noDesignators,
					true
				)
			))
		);

		// monitor_desc * __monitors[] = { get_monitor(a), get_monitor(b) };
		body->push_front( new DeclStmt( noLabels, monitors) );
	}
};
