//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Waitfor.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Aug 28 11:06:52 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 5
//

#include "Concurrency/Keywords.h"

#include <cassert>                 // for assert
#include <string>                  // for string, operator==

using namespace std::string_literals;

#include "Common/PassVisitor.h"    // for PassVisitor
#include "Common/SemanticError.h"  // for SemanticError
#include "Common/utility.h"        // for deleteAll, map_range
#include "CodeGen/OperatorTable.h" // for isConstructor
#include "InitTweak/InitTweak.h"   // for getPointerBase
#include "Parser/LinkageSpec.h"    // for Cforall
#include "SymTab/AddVisit.h"       // for acceptAndAdd
#include "SynTree/Constant.h"      // for Constant
#include "SynTree/Declaration.h"   // for StructDecl, FunctionDecl, ObjectDecl
#include "SynTree/Expression.h"    // for VariableExpr, ConstantExpr, Untype...
#include "SynTree/Initializer.h"   // for SingleInit, ListInit, Initializer ...
#include "SynTree/Label.h"         // for Label
#include "SynTree/Statement.h"     // for CompoundStmt, DeclStmt, ExprStmt
#include "SynTree/Type.h"          // for StructInstType, Type, PointerType
#include "SynTree/Visitor.h"       // for Visitor, acceptAll

class Attribute;
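/*
Design sketch (illustrative only, not compiled): a waitfor statement such as the one in the first
foo() below is meant to be lowered into the second foo(): an array of acceptable_t entries filled in
per clause, a call into the runtime that returns the index of the accepted clause, and a switch on
that index that runs the matching clause body.
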
void foo() {
	while( true ) {
		when( a < 1 ) waitfor( f, a ) { bar(); }
		or timeout( swagl() );
		or waitfor( g, a ) { baz(); }
		or waitfor( ^?{}, a ) { break; }
		or waitfor( ^?{} ) { break; }
	}
}

void f(int i, float f, A & mutex b, struct foo *  );
void f(int );


                      |  |
                      |  |
                      |  |
                      |  |
                      |  |
                    \ |  | /
                     \    /
                      \  /
                       \/


void foo() {
	while( true ) {

		acceptable_t acceptables[3];
		if( a < 1 ) {
			acceptables[0].func = f;
			acceptables[0].mon = a;
		}
		acceptables[1].func = g;
		acceptables[1].mon = a;

		acceptables[2].func = f;
		acceptables[2].mon = a;
		acceptables[2].is_dtor = true;

		int ret = waitfor_internal( acceptables, swagl() );

		switch( ret ) {
			case 0:
			{
				bar();
			}
			case 1:
			{
				baz();
			}
			case 2:
				signal(a);
				{
					break;
				}
		}
	}
}*/

namespace Concurrency {

	namespace {
		// Shared "empty" arguments for SynTree node constructors in this file: noLabels,
		// noStorage and noQualifiers are passed to the DeclStmt/ExprStmt/IfStmt/ObjectDecl/type
		// constructors below. noAttributes is not referenced in this file.
		// NOTE(review): noStorage/noQualifiers are presumably default-constructed to "none";
		// confirm against Type::StorageClasses / Type::Qualifiers.
		const std::list<Label> noLabels;
		const std::list< Attribute * > noAttributes;
		Type::StorageClasses noStorage;
		Type::Qualifiers noQualifiers;
	}

	//=============================================================================================
	// Pass declarations
	//=============================================================================================

	// GenerateWaitForPass
	//   Mutator pass meant to lower WaitForStmt nodes into plain statements.
	//   - premutate(FunctionDecl*) remembers the declaration named "__accept_internal".
	//   - premutate(StructDecl*) remembers the definitions of "__acceptable_t" (plus its fields
	//     func/count/monitor/is_dtor) and "monitor_desc".
	//   - postmutate(WaitForStmt*) builds the replacement statement from those declarations.
	//   NOTE(review): in this revision postmutate returns the WaitForStmt unchanged, so the
	//   lowering is effectively disabled.
	class GenerateWaitForPass final : public WithStmtsToAdd {
	  public:

		// Declaration-recording hooks (see class comment).
		void premutate( FunctionDecl * decl );
		void premutate( StructDecl   * decl );

		// Rewrites a waitfor statement; the returned statement replaces it in the tree.
		Statement * postmutate( WaitForStmt * stmt );

		// NOTE(review): walks with acceptAll (a visit, not a mutate) while this pass only defines
		// premutate/postmutate hooks; it is not called from this file — confirm whether it is needed.
		static void generate( std::list< Declaration * > & translationUnit ) {
			PassVisitor< GenerateWaitForPass > impl;
			acceptAll( translationUnit, impl );
		}

		// Lowering helpers; each appends the statements it generates to `stmt`.
		ObjectDecl * declare( unsigned long count, CompoundStmt * stmt );
		ObjectDecl * declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt );
		void         init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, CompoundStmt * stmt );
		Expression * init_timeout( Expression *& time, Expression *& time_cond, bool has_else, Expression *& else_cond, CompoundStmt * stmt );
		// NOTE(review): declared but not defined in this file (only referenced in comments).
		Expression * call();
		void choose();

		// Driver used by generateWaitFor: runs this pass as a mutator over the translation unit.
		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< GenerateWaitForPass > impl;
			mutateAll( translationUnit, impl );
		}


	  private:
	  	// Declarations captured by the premutate hooks; nullptr until seen.
	  	FunctionDecl        * decl_waitfor    = nullptr;
		StructDecl          * decl_acceptable = nullptr;
		StructDecl          * decl_monitor    = nullptr;
		DeclarationWithType * decl_m_func     = nullptr;
		DeclarationWithType * decl_m_count    = nullptr;
		DeclarationWithType * decl_m_monitors = nullptr;
		DeclarationWithType * decl_m_isdtor   = nullptr;

		// NOTE(review): declared but neither defined nor referenced in this file.
		static std::unique_ptr< Type > generic_func;

		// Generators for fresh names of the temporaries the lowering declares.
		UniqueName namer_mon = "__monitors_"s;
		UniqueName namer_acc = "__acceptables_"s;
		UniqueName namer_tim = "__timeout_"s;
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	// Public entry point (presumably declared in Concurrency/Keywords.h): lowers every waitfor
	// statement found in the given translation unit, mutating the declarations in place.
	void generateWaitFor( std::list< Declaration * > & translationUnit ) {
		// All of the work is done by the pass's own static driver.
		GenerateWaitForPass::implement( translationUnit );
	}

	//=============================================================================================
	// Generic helper routine
	//=============================================================================================

	namespace {
		// Builds the unresolved call `array[index]`.
		// The operator is named, not bound, so the node must be an UntypedExpr: ApplicationExpr
		// is the resolved form and expects a function expression that already carries a type,
		// which a bare NameExpr does not.
		Expression * makeOpIndex( DeclarationWithType * array, unsigned long index ) {
			return new UntypedExpr(
				new NameExpr( "?[?]" ),
				{
					new VariableExpr( array ),
					new ConstantExpr( Constant::from_ulong( index ) )
				}
			);
		}

		// Builds the unresolved assignment `lhs = rhs` (ownership of lhs/rhs moves into the node).
		Expression * makeOpAssign( Expression * lhs, Expression * rhs ) {
			return new UntypedExpr(
					new NameExpr( "?=?" ),
					{ lhs, rhs }
			);
		}

		// Builds the member access `sue.mem`.
		// NOTE(review): MemberExpr is the typed form, while `sue` is typically the untyped result
		// of makeOpIndex; an UntypedMemberExpr over NameExpr( mem->name ) may be required — confirm.
		Expression * makeOpMember( Expression * sue, DeclarationWithType * mem ) {
			return new MemberExpr( mem, sue );
		}

		// Builds the statement `object[index].member = value;`.
		Statement * makeAccStatement( DeclarationWithType * object, unsigned long index, DeclarationWithType * member, Expression * value ) {
			return new ExprStmt(
				noLabels,
				makeOpAssign(
					makeOpMember(
						makeOpIndex(
							object,
							index
						),
						member
					),
					value
				)
			);
		}

		// Returns `expr` unchanged, or a boolean constant `ifnull` when no condition was given
		// (e.g. a clause without a `when` guard).
		Expression * safeCond( Expression * expr, bool ifnull = true ) {
			if( expr ) return expr;

			return new ConstantExpr( Constant::from_bool( ifnull ) );
		}
	}


	//=============================================================================================
	// Generate waitfor implementation
	//=============================================================================================

	// Records the declaration of "__accept_internal" when the walk reaches it; all other
	// function declarations are ignored. If several match, the last one seen is kept.
	void GenerateWaitForPass::premutate( FunctionDecl * decl ) {
		const bool isAcceptInternal = ( decl->name == "__accept_internal" );
		if( isAcceptInternal ) {
			decl_waitfor = decl;
		}
	}

	// Records the struct definitions the lowering depends on:
	//   - "monitor_desc"   -> decl_monitor
	//   - "__acceptable_t" -> decl_acceptable, plus its fields func / count / monitor / is_dtor
	// Declarations without a body (presumably forward declarations) are skipped. Each struct is
	// asserted to be defined at most once.
	void GenerateWaitForPass::premutate( StructDecl   * decl ) {
		if( ! decl->body ) return;

		if( decl->name == "monitor_desc" ) {
			assert( !decl_monitor );
			decl_monitor = decl;
			return;
		}

		if( decl->name != "__acceptable_t" ) return;

		assert( !decl_acceptable );
		decl_acceptable = decl;

		// Cache the fields by name; unknown fields are ignored.
		for( Declaration * member : decl->members ) {
			if( member->name == "func" ) {
				decl_m_func = strict_dynamic_cast< DeclarationWithType * >( member );
			} else if( member->name == "count" ) {
				decl_m_count = strict_dynamic_cast< DeclarationWithType * >( member );
			} else if( member->name == "monitor" ) {
				decl_m_monitors = strict_dynamic_cast< DeclarationWithType * >( member );
			} else if( member->name == "is_dtor" ) {
				decl_m_isdtor = strict_dynamic_cast< DeclarationWithType * >( member );
			}
		}
	}

	// Lowers one waitfor statement into a CompoundStmt that declares and fills an acceptables
	// array (one entry per clause) and a timeout variable.
	// NOTE(review): work in progress — the first statement returns `waitfor` unchanged, so every
	// line after it is unreachable and WaitForStmt nodes currently pass through untouched.
	Statement * GenerateWaitForPass::postmutate( WaitForStmt * waitfor ) {
		return waitfor;

		// The lowering needs the __acceptable_t and monitor_desc definitions recorded by premutate.
		if( !decl_monitor || !decl_acceptable ) throw SemanticError( "waitfor keyword requires monitors to be in scope, add #include <monitor>", waitfor );

		CompoundStmt * stmt = new CompoundStmt( noLabels );

		ObjectDecl * acceptables = declare( waitfor->clauses.size(), stmt );

		// Fill acceptables[index] for each clause, in clause order.
		int index = 0;
		for( auto & clause : waitfor->clauses ) {
			init( acceptables, index, clause, stmt );

			index++;
		}

		// orelse.statement is passed where a bool is expected: "has an else clause".
		// NOTE(review): `timeout` is not used yet; it is meant for the commented-out call() below.
		Expression * timeout = init_timeout(
			waitfor->timeout.time,
			waitfor->timeout.condition,
			waitfor->orelse .statement,
			waitfor->orelse .condition,
			stmt
		);

		// Not implemented yet: call the runtime and dispatch on the accepted clause.
		// Expression * result  = call( acceptables, timeout, orelse, stmt );

		// choose( waitfor, result );

		// NOTE(review): as written, the clause bodies are not part of `stmt`; choose() is
		// presumably expected to add them.
		return stmt;
	}

	// Declares a fresh, uninitialised local array `__acceptables_N` of `count` __acceptable_t
	// elements, appends its DeclStmt to `stmt`, and returns the new object declaration.
	ObjectDecl * GenerateWaitForPass::declare( unsigned long count, CompoundStmt * stmt )
	{
		StructInstType * elementType = new StructInstType( noQualifiers, decl_acceptable );
		ConstantExpr   * dimension   = new ConstantExpr( Constant::from_ulong( count ) );
		ArrayType      * arrayType   = new ArrayType( noQualifiers, elementType, dimension, false, false );

		ObjectDecl * result = new ObjectDecl(
			namer_acc.newName(),
			noStorage,
			LinkageSpec::Cforall,
			nullptr,     // no bit-field width
			arrayType,
			nullptr      // no initializer
		);

		stmt->push_back( new DeclStmt( noLabels, result ) );
		return result;
	}

	// Declares a fresh local array `__monitors_N` holding one monitor_desc element per argument
	// of the clause's target, initialised from those argument expressions, appends its DeclStmt to
	// `stmt`, and returns the new object declaration.
	// The argument expressions are moved into the generated ListInit; the caller (init) detaches
	// them from the clause afterwards.
	// NOTE(review): elements are monitor_desc values initialised directly from the arguments;
	// confirm whether the runtime expects pointers (or a conversion call) here instead.
	ObjectDecl * GenerateWaitForPass::declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt ) {

		ObjectDecl * mon = new ObjectDecl(
			namer_mon.newName(),
			noStorage,
			LinkageSpec::Cforall,
			nullptr,
			new ArrayType(
				noQualifiers,
				new StructInstType(
					noQualifiers,
					decl_monitor
				),
				new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ),
				false,
				false
			),
			new ListInit(
				// Wrap each argument in a SingleInit; the lambda needs no captures.
				map_range < std::list<Initializer*> > ( clause.target.arguments, [](Expression * expr ){
					return new SingleInit( expr );
				})
			)
		);

		stmt->push_back( new DeclStmt( noLabels, mon) );

		return mon;
	}

	// Emits the code that fills `acceptables[index]` for one waitfor clause:
	//   __monitors_N = { args... };
	//   if( <when-condition, or true> ) {
	//       acceptables[index].func    = <target function>;
	//       acceptables[index].count   = <number of args>;
	//       acceptables[index].monitor = __monitors_N;
	//       acceptables[index].is_dtor = true;
	//   }
	// The clause's function, arguments and condition expressions are moved into the generated
	// statements, so the clause is detached from them before returning.
	void GenerateWaitForPass::init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, CompoundStmt * stmt ) {

		ObjectDecl * monitors = declMon( clause, stmt );

		CompoundStmt * compound = new CompoundStmt( noLabels );
		compound->push_back( makeAccStatement( acceptables, index, decl_m_func    , clause.target.function ) );
		compound->push_back( makeAccStatement( acceptables, index, decl_m_count   , new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ) ) );
		compound->push_back( makeAccStatement( acceptables, index, decl_m_monitors, new VariableExpr( monitors ) ) );
		// NOTE(review): is_dtor is set to true for every clause; the design sketch at the top of
		// the file sets it only for destructor (^?{}) clauses — confirm and detect the dtor case.
		compound->push_back( makeAccStatement( acceptables, index, decl_m_isdtor  , new ConstantExpr( Constant::from_bool( true ) ) ) );

		// A clause without a `when` guard is always eligible.
		stmt->push_back( new IfStmt(
			noLabels,
			safeCond( clause.condition ),
			compound,
			nullptr
		));

		// Detach the moved expressions from the clause so they are not owned twice.
		// (std::list::empty() only tests for emptiness; clear() actually removes the elements.)
		clause.target.function = nullptr;
		clause.target.arguments.clear();
		clause.condition = nullptr;
	}

	// Emits the code computing the timeout argument of a waitfor statement and returns an
	// expression that reads it:
	//   unsigned long long __timeout_N = -1;
	//   if( <timeout condition, or true> ) __timeout_N = <time>;   // only if a timeout clause exists
	//   if( <else condition, or true> )    __timeout_N = 0;        // only if an else clause exists
	//
	// time      : timeout expression, or nullptr when there is no timeout clause; reset to nullptr
	//             (together with time_cond) once moved into the generated code.
	// time_cond : `when` guard of the timeout clause, or nullptr for "always".
	// has_else  : whether the statement has an else clause.
	// else_cond : `when` guard of the else clause, or nullptr for "always"; reset to nullptr
	//             once moved into the generated code.
	// stmt      : statement list the generated code is appended to.
	Expression * GenerateWaitForPass::init_timeout(
		Expression *& time,
		Expression *& time_cond,
		bool has_else,
		Expression *& else_cond,
		CompoundStmt * stmt
	) {
		// -1 converted to unsigned long long is its largest value.
		// NOTE(review): presumably means "no timeout" to the runtime — confirm.
		ObjectDecl * timeout = new ObjectDecl(
			namer_tim.newName(),
			noStorage,
			LinkageSpec::Cforall,
			nullptr,
			new BasicType(
				noQualifiers,
				BasicType::LongLongUnsignedInt
			),
			new SingleInit(
				new ConstantExpr( Constant::from_int( -1 ) )
			)
		);

		stmt->push_back( new DeclStmt( noLabels, timeout ) );

		if( time ) {
			// The timeout clause is guarded by its own condition (not the else clause's).
			stmt->push_back( new IfStmt(
				noLabels,
				safeCond( time_cond ),
				new ExprStmt(
					noLabels,
					makeOpAssign(
						new VariableExpr( timeout ),
						time
					)
				),
				nullptr
			));

			// Both expressions now belong to the generated IfStmt.
			time = time_cond = nullptr;
		}

		if( has_else ) {
			// An enabled else clause forces a zero timeout.
			stmt->push_back( new IfStmt(
				noLabels,
				safeCond( else_cond ),
				new ExprStmt(
					noLabels,
					makeOpAssign(
						new VariableExpr( timeout ),
						new ConstantExpr( Constant::from_ulong( 0 ) )
					)
				),
				nullptr
			));

			else_cond = nullptr;
		}

		return new VariableExpr( timeout );
	}
};

// Local Variables: //
// mode: c //
// tab-width: 4 //
// End: //