//                              -*- Mode: CPP -*-
//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Keywords.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Mar 13 12:41:22 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 0
//

#include "Concurrency/Keywords.h"

#include "SynTree/Declaration.h"
#include "SynTree/Expression.h"
#include "SynTree/Initializer.h"
#include "SynTree/Mutator.h"
#include "SynTree/Statement.h"
#include "SynTree/Type.h"
#include "SynTree/Visitor.h"

namespace Concurrency {

	namespace {
		const std::list<Label> noLabels;
		DeclarationNode::StorageClasses noStorage;
		Type::Qualifiers noQualifiers;
	}

	//=============================================================================================
	// Visitors declaration
	//=============================================================================================

	//-----------------------------------------------------------------------------
	//Handles thread type declarations :
	// thread Mythread {                         struct MyThread {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             thread_desc __thrd_d;
	// };                                        };
	//                                           static inline thread_desc * get_thread( MyThread * this ) { return &this->__thrd_d; }
	//                                           void main( MyThread * this );
	//
	class ThreadKeyword final : public Mutator {
	  public:

		static void implement( std::list< Declaration * > & translationUnit ) {}
	};

	//-----------------------------------------------------------------------------
	//Handles coroutine type declarations :
	// coroutine MyCoroutine {                   struct MyCoroutine {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             coroutine_desc __cor_d;
	// };                                        };
	//                                           static inline coroutine_desc * get_coroutine( MyCoroutine * this ) { return &this->__cor_d; }
	//                                           void main( MyCoroutine * this );
	//
	class CoroutineKeyword final : public Mutator {
	  public:

		static void implement( std::list< Declaration * > & translationUnit ) {}
	};

	//-----------------------------------------------------------------------------
	//Handles monitor type declarations :
	// monitor MyMonitor {                       struct MyMonitor {
	// 	int data;                                  int data;
	// 	a_struct_t more_data;                      a_struct_t more_data;
	//                                =>             monitor_desc __mon_d;
	// };                                        };
	//                                           static inline monitor_desc * get_coroutine( MyMonitor * this ) { return &this->__cor_d; }
	//                                           void main( MyMonitor * this );
	//
	class MonitorKeyword final : public Mutator {
	  public:

		static void implement( std::list< Declaration * > & translationUnit ) {}
	};

	//-----------------------------------------------------------------------------
	//Handles mutex routines definitions :
	// void foo( A * mutex a, B * mutex b,  int i ) {                  void foo( A * a, B * b,  int i ) {
	// 	                                                                 monitor_desc * __monitors[] = { a, b };
	// 	                                                                 monitor_guard_t __guard = { __monitors, 2 };
	//    /*Some code*/                                       =>           /*Some code*/
	// }                                                               }
	//
	class MutexKeyword final : public Visitor {
	  public:

		using Visitor::visit;
		virtual void visit( FunctionDecl *functionDecl ) override final;

		std::list<DeclarationWithType*> findMutexArgs( FunctionDecl* );
		void validate( DeclarationWithType * );
		void addStatments( CompoundStmt *, const std::list<DeclarationWithType * > &);

		static void implement( std::list< Declaration * > & translationUnit ) {
			MutexKeyword impl;
			acceptAll( translationUnit, impl );
		}
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	void applyKeywords( std::list< Declaration * > & translationUnit ) {
		ThreadKeyword	::implement( translationUnit );
		CoroutineKeyword	::implement( translationUnit );
		MonitorKeyword	::implement( translationUnit );
		MutexKeyword	::implement( translationUnit );
	}

	//=============================================================================================
	// Mutex keyword implementation
	//=============================================================================================
	void MutexKeyword::visit(FunctionDecl* decl) {
		std::list<DeclarationWithType*> mutexArgs = findMutexArgs( decl );
		if( mutexArgs.empty() ) return;

		for(auto arg : mutexArgs) {
			validate( arg );
		}

		CompoundStmt* body = decl->get_statements();
		if( ! body ) return;

		addStatments( body, mutexArgs );
	}

	std::list<DeclarationWithType*> MutexKeyword::findMutexArgs( FunctionDecl* decl ) {
		std::list<DeclarationWithType*> mutexArgs;

		for( auto arg : decl->get_functionType()->get_parameters()) {
			//Find mutex arguments
			Type* ty = arg->get_type();
			if( ! ty->get_qualifiers().isMutex ) continue;

			//Append it to the list
			mutexArgs.push_back( arg );
		}

		return mutexArgs;
	}

	void MutexKeyword::validate( DeclarationWithType * arg ) {
		Type* ty = arg->get_type();

		//Makes sure it's not a copy
		PointerType* pty = dynamic_cast< PointerType * >( ty );
		if( ! pty ) throw SemanticError( "Mutex argument must be of pointer/reference type ", arg );

		//Make sure the we are pointing directly to a type
		Type* base = pty->get_base();
		if(  dynamic_cast< PointerType * >( base ) ) throw SemanticError( "Mutex argument have exactly one level of indirection ", arg );

		//Make sure that typed isn't mutex
		if( ! base->get_qualifiers().isMutex ) throw SemanticError( "mutex keyword may only appear once per argument ", arg );
	}

	void MutexKeyword::addStatments( CompoundStmt * body, const std::list<DeclarationWithType * > & args ) {
		//in reverse order :
		// monitor_guard_t __guard = { __monitors, # };
		body->push_front(
			new DeclStmt( noLabels, new ObjectDecl(
				"__guard",
				noStorage,
				LinkageSpec::Cforall,
				nullptr,
				new StructInstType(
					noQualifiers,
					"monitor_guard_t"
				),
				new ListInit(
					{
						new SingleInit( new NameExpr( "__monitors" ) ),
						new SingleInit( new ConstantExpr( Constant::from_ulong( args.size() ) ) )
					}
				)
			))
		);

		//monitor_desc * __monitors[] = { a, b };
		body->push_front(
			new DeclStmt( noLabels, new ObjectDecl(
				"__monitors",
				noStorage,
				LinkageSpec::Cforall,
				nullptr,
				new ArrayType(
					noQualifiers,
					new PointerType(
						noQualifiers,
						new StructInstType(
							noQualifiers,
							"monitor_desc"
						)
					),
					new ConstantExpr( Constant::from_ulong( args.size() ) ),
					false,
					false
				),
				new ListInit(
					map_range < std::list<Initializer*> > ( args, [](DeclarationWithType * var ){
						return new SingleInit( new VariableExpr( var ) );
					})
				)
			))
		);
	}
};