//
// Cforall Version 1.0.0 Copyright (C) 2016 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Waitfor.cc --
//
// Author           : Thierry Delisle
// Created On       : Mon Aug 28 11:06:52 2017
// Last Modified By :
// Last Modified On :
// Update Count     : 5
//

#include "Concurrency/Keywords.h"

#include <cassert>                 // for assert
#include <string>                  // for string, operator==

using namespace std::string_literals;

#include "Common/PassVisitor.h"    // for PassVisitor
#include "Common/SemanticError.h"  // for SemanticError
#include "Common/utility.h"        // for deleteAll, map_range
#include "CodeGen/OperatorTable.h" // for isConstructor
#include "InitTweak/InitTweak.h"   // for getPointerBase
#include "Parser/LinkageSpec.h"    // for Cforall
#include "ResolvExpr/Resolver.h"   // for findVoidExpression
#include "SynTree/Constant.h"      // for Constant
#include "SynTree/Declaration.h"   // for StructDecl, FunctionDecl, ObjectDecl
#include "SynTree/Expression.h"    // for VariableExpr, ConstantExpr, Untype...
#include "SynTree/Initializer.h"   // for SingleInit, ListInit, Initializer ...
#include "SynTree/Label.h"         // for Label
#include "SynTree/Statement.h"     // for CompoundStmt, DeclStmt, ExprStmt
#include "SynTree/Type.h"          // for StructInstType, Type, PointerType
#include "SynTree/Visitor.h"       // for Visitor, acceptAll

class Attribute;
/*
void foo() {
	while( true ) {
		when( a < 1 ) waitfor( f, a ) { bar(); }
		or timeout( swagl() );
		or waitfor( g, a ) { baz(); }
		or waitfor( ^?{}, a ) { break; }
		or waitfor( ^?{} ) { break; }
	}
}

void f(int i, float f, A & mutex b, struct foo *  );
void f(int );


                      |  |
                      |  |
			    |  |
                      |  |
                      |  |
                    \ |  | /
                     \    /
                      \  /
                       \/


void foo() {
	while( true ) {

		acceptable_t acceptables[3];
		if( a < 1 ) {
			acceptables[0].func = f;
			acceptables[0].mon = a;
		}
		acceptables[1].func = g;
		acceptables[1].mon = a;

		acceptables[2].func = f;
		acceptables[2].mon = a;
		acceptables[2].is_dtor = true;

		int ret = waitfor_internal( acceptables, swagl() );

		switch( ret ) {
			case 0:
			{
				bar();
			}
			case 1:
			{
				baz();
			}
			case 2:
				signal(a);
				{
					break;
				}
		}
	}
}*/

namespace Concurrency {

	namespace {
		const std::list<Label> noLabels;
		const std::list< Attribute * > noAttributes;
		Type::StorageClasses noStorage;
		Type::Qualifiers noQualifiers;
	}

	//=============================================================================================
	// Pass declarations
	//=============================================================================================

	class GenerateWaitForPass final : public WithIndexer {
	  public:

		void premutate( FunctionDecl * decl );
		void premutate( StructDecl   * decl );

		Statement * postmutate( WaitForStmt * stmt );

		static void generate( std::list< Declaration * > & translationUnit ) {
			PassVisitor< GenerateWaitForPass > impl;
			acceptAll( translationUnit, impl );
		}

		ObjectDecl * declare( unsigned long count, CompoundStmt * stmt );
		ObjectDecl * declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt );
		void         init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, CompoundStmt * stmt );
		Expression * init_timeout( Expression *& time, Expression *& time_cond, bool has_else, Expression *& else_cond, CompoundStmt * stmt );
		Expression * call(size_t count, ObjectDecl * acceptables, Expression * timeout, CompoundStmt * stmt);
		void         choose( WaitForStmt * waitfor, Expression  * result, CompoundStmt * stmt );

		static void implement( std::list< Declaration * > & translationUnit ) {
			PassVisitor< GenerateWaitForPass > impl;
			mutateAll( translationUnit, impl );
		}


	  private:
	  	FunctionDecl        * decl_waitfor    = nullptr;
		StructDecl          * decl_acceptable = nullptr;
		StructDecl          * decl_monitor    = nullptr;

		static std::unique_ptr< Type > generic_func;

		UniqueName namer_mon = "__monitors_"s;
		UniqueName namer_acc = "__acceptables_"s;
		UniqueName namer_tim = "__timeout_"s;
		UniqueName namer_ret = "__return_"s;
	};

	//=============================================================================================
	// General entry routine
	//=============================================================================================
	void generateWaitFor( std::list< Declaration * > & translationUnit ) {
		GenerateWaitForPass	::implement( translationUnit );
	}

	//=============================================================================================
	// Generic helper routine
	//=============================================================================================

	namespace {
		Expression * makeOpIndex( DeclarationWithType * array, unsigned long index ) {
			return new UntypedExpr(
				new NameExpr( "?[?]" ),
				{
					new VariableExpr( array ),
					new ConstantExpr( Constant::from_ulong( index ) )
				}
			);
		}

		Expression * makeOpAssign( Expression * lhs, Expression * rhs ) {
			return new UntypedExpr(
					new NameExpr( "?=?" ),
					{ lhs, rhs }
			);
		}

		Expression * makeOpMember( Expression * sue, const std::string & mem ) {
			return new UntypedMemberExpr( new NameExpr( mem ), sue );
		}

		Statement * makeAccStatement( DeclarationWithType * object, unsigned long index, const std::string & member, Expression * value, const SymTab::Indexer & indexer ) {
			std::unique_ptr< Expression > expr( makeOpAssign(
				makeOpMember(
					makeOpIndex(
						object,
						index
					),
					member
				),
				value
			) );

			return new ExprStmt( noLabels, ResolvExpr::findVoidExpression( expr.get(), indexer ) );
		}

		Expression * safeCond( Expression * expr, bool ifnull = true ) {
			if( expr ) return expr;

			return new ConstantExpr( Constant::from_bool( ifnull ) );
		}

		VariableExpr * extractVariable( Expression * func ) {
			if( VariableExpr * var = dynamic_cast< VariableExpr * >( func ) ) {
				return var;
			}

			CastExpr * cast = strict_dynamic_cast< CastExpr * >( func );
			return strict_dynamic_cast< VariableExpr * >( cast->arg );
		}

		Expression * betterIsDtor( Expression * func ) {
			VariableExpr * typed_func = extractVariable( func );
			bool is_dtor = InitTweak::isDestructor( typed_func->var );
			return new ConstantExpr( Constant::from_bool( is_dtor ) );
		}
	};


	//=============================================================================================
	// Generate waitfor implementation
	//=============================================================================================

	void GenerateWaitForPass::premutate( FunctionDecl * decl) {
		if( decl->name != "__waitfor_internal" ) return;

		decl_waitfor = decl;
	}

	void GenerateWaitForPass::premutate( StructDecl   * decl ) {
		if( ! decl->body ) return;

		if( decl->name == "__acceptable_t" ) {
			assert( !decl_acceptable );
			decl_acceptable = decl;
		}
		else if( decl->name == "monitor_desc" ) {
			assert( !decl_monitor );
			decl_monitor = decl;
		}
	}

	Statement * GenerateWaitForPass::postmutate( WaitForStmt * waitfor ) {
		if( !decl_monitor || !decl_acceptable ) throw SemanticError( "waitfor keyword requires monitors to be in scope, add #include <monitor>", waitfor );

		CompoundStmt * stmt = new CompoundStmt( noLabels );

		ObjectDecl * acceptables = declare( waitfor->clauses.size(), stmt );

		int index = 0;
		for( auto & clause : waitfor->clauses ) {
			init( acceptables, index, clause, stmt );

			index++;
		}

		Expression * timeout = init_timeout(
			waitfor->timeout.time,
			waitfor->timeout.condition,
			waitfor->orelse .statement,
			waitfor->orelse .condition,
			stmt
		);

		Expression * result = call( waitfor->clauses.size(), acceptables, timeout, stmt );

		choose( waitfor, result, stmt );

		return stmt;
	}

	ObjectDecl * GenerateWaitForPass::declare( unsigned long count, CompoundStmt * stmt )
	{
		ObjectDecl * acceptables = ObjectDecl::newObject(
			namer_acc.newName(),
			new ArrayType(
				noQualifiers,
				new StructInstType(
					noQualifiers,
					decl_acceptable
				),
				new ConstantExpr( Constant::from_ulong( count ) ),
				false,
				false
			),
			nullptr
		);

		stmt->push_back( new DeclStmt( noLabels, acceptables) );

		return acceptables;
	}

	ObjectDecl * GenerateWaitForPass::declMon( WaitForStmt::Clause & clause, CompoundStmt * stmt ) {

		ObjectDecl * mon = ObjectDecl::newObject(
			namer_mon.newName(),
			new ArrayType(
				noQualifiers,
				new StructInstType(
					noQualifiers,
					decl_monitor
				),
				new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ),
				false,
				false
			),
			new ListInit(
				map_range < std::list<Initializer*> > ( clause.target.arguments, [this](Expression * expr ){
					return new SingleInit( expr );
				})
			)
		);

		stmt->push_back( new DeclStmt( noLabels, mon) );

		return mon;
	}

	void GenerateWaitForPass::init( ObjectDecl * acceptables, int index, WaitForStmt::Clause & clause, CompoundStmt * stmt ) {

		ObjectDecl * monitors = declMon( clause, stmt );

		Type * fptr_t = new PointerType( noQualifiers, new FunctionType( noQualifiers, true ) );

		Expression * is_dtor = betterIsDtor( clause.target.function );
		CompoundStmt * compound = new CompoundStmt( noLabels );
		compound->push_back( makeAccStatement( acceptables, index, "func"    , new CastExpr( clause.target.function, fptr_t )                            , indexer ) );
		compound->push_back( makeAccStatement( acceptables, index, "count"   , new ConstantExpr( Constant::from_ulong( clause.target.arguments.size() ) ), indexer ) );
		compound->push_back( makeAccStatement( acceptables, index, "monitors", new VariableExpr( monitors )                                              , indexer ) );
		compound->push_back( makeAccStatement( acceptables, index, "is_dtor" , is_dtor                                                                   , indexer ) );

		stmt->push_back( new IfStmt(
			noLabels,
			safeCond( clause.condition ),
			compound,
			nullptr
		));

		clause.target.function = nullptr;
		clause.target.arguments.empty();
		clause.condition = nullptr;
	}

	Expression * GenerateWaitForPass::init_timeout(
		Expression *& time,
		Expression *& time_cond,
		bool has_else,
		Expression *& else_cond,
		CompoundStmt * stmt
	) {
		ObjectDecl * timeout = ObjectDecl::newObject(
			namer_tim.newName(),
			new BasicType(
				noQualifiers,
				BasicType::LongLongUnsignedInt
			),
			new SingleInit(
				new ConstantExpr( Constant::from_int( -1 ) )
			)
		);

		stmt->push_back( new DeclStmt( noLabels, timeout ) );

		if( time ) {
			stmt->push_back( new IfStmt(
				noLabels,
				safeCond( time_cond ),
				new ExprStmt(
					noLabels,
					makeOpAssign(
						new VariableExpr( timeout ),
						time
					)
				),
				nullptr
			));

			time = time_cond = nullptr;
		}

		if( has_else ) {
			stmt->push_back( new IfStmt(
				noLabels,
				safeCond( else_cond ),
				new ExprStmt(
					noLabels,
					makeOpAssign(
						new VariableExpr( timeout ),
						new ConstantExpr( Constant::from_ulong( 0 ) )
					)
				),
				nullptr
			));

			else_cond = nullptr;
		}

		return new VariableExpr( timeout );
	}

	Expression * GenerateWaitForPass::call(
		size_t count,
		ObjectDecl * acceptables,
		Expression * timeout,
		CompoundStmt * stmt
	) {
		ObjectDecl * decl = ObjectDecl::newObject(
			namer_ret.newName(),
			new BasicType(
				noQualifiers,
				BasicType::LongLongUnsignedInt
			),
			new SingleInit(
				new UntypedExpr(
					VariableExpr::functionPointer( decl_waitfor ),
					{
						new ConstantExpr( Constant::from_ulong( count ) ),
						new VariableExpr( acceptables ),
						timeout
					}
				)
			)
		);

		stmt->push_back( new DeclStmt( noLabels, decl ) );

		return new VariableExpr( decl );
	}

	void GenerateWaitForPass::choose(
		WaitForStmt * waitfor,
		Expression  * result,
		CompoundStmt * stmt
	) {
		SwitchStmt * swtch = new SwitchStmt(
			noLabels,
			result,
			std::list<Statement *>()
		);

		unsigned long i = 0;
		for( auto & clause : waitfor->clauses ) {
			swtch->statements.push_back(
				new CaseStmt(
					noLabels,
					new ConstantExpr( Constant::from_ulong( i++ ) ),
					{
						clause.statement,
						new BranchStmt(
							noLabels,
							"",
							BranchStmt::Break
						)
					}
				)
			);
		}

		if(waitfor->timeout.statement) {
			swtch->statements.push_back(
				new CaseStmt(
					noLabels,
					new ConstantExpr( Constant::from_ulong( i++ ) ),
					{
						waitfor->timeout.statement,
						new BranchStmt(
							noLabels,
							"",
							BranchStmt::Break
						)
					}
				)
			);
		}

		if(waitfor->orelse.statement) {
			swtch->statements.push_back(
				new CaseStmt(
					noLabels,
					new ConstantExpr( Constant::from_ulong( i++ ) ),
					{
						waitfor->orelse.statement,
						new BranchStmt(
							noLabels,
							"",
							BranchStmt::Break
						)
					}
				)
			);
		}

		stmt->push_back( swtch );
	}
};

// Local Variables: //
// mode: c //
// tab-width: 4 //
// End: //