//
// Cforall Version 1.0.0 Copyright (C) 2019 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Pass.impl.hpp --
//
// Author           : Thierry Delisle
// Created On       : Thu May 09 15::37::05 2019
// Last Modified By :
// Last Modified On :
// Update Count     :
//

#pragma once
// IWYU pragma: private, include "AST/Pass.hpp"

#include <type_traits>
#include <unordered_map>

// Prologue shared by every Pass<pass_t>::visit() implementation.
// guard1 backs up the pass's visit-children state (ast::__pass::visit_children) and guard2 sets up the
// scope for passes that want to run code at exit (ast::__pass::at_cleanup); both live until the visit
// returns. Finally the pass's previsit hook is called on `node`.
// NOTE(review): the exact save/restore semantics of visit_children_guard / guard_value are defined
// elsewhere — confirm there.
#define VISIT_START( node ) \
	using namespace ast; \
	/* back-up the visit children */ \
	__attribute__((unused)) ast::__pass::visit_children_guard guard1( ast::__pass::visit_children(pass, 0) ); \
	/* setup the scope for passes that want to run code at exit */ \
	__attribute__((unused)) ast::__pass::guard_value          guard2( ast::__pass::at_cleanup    (pass, 0) ); \
	/* call the implementation of the previsit of this pass */ \
	__pass::previsit( pass, node, 0 );

// Runs `code` only when __visit_children() is true. The parameter is variadic so that `code`
// may contain unparenthesised commas (e.g. template argument lists).
#define VISIT( code... ) \
	/* if this node should visit its children */ \
	if ( __visit_children() ) { \
		/* visit the children */ \
		code \
	}

// Epilogue shared by every visit(): calls the pass's postvisit hook, asserts that it produced a
// non-null node and returns it. The `type` argument is not used by the expansion.
#define VISIT_END( type, node ) \
	/* call the implementation of the postvisit of this pass */ \
	auto __return = __pass::postvisit( pass, node, 0 ); \
	assertf(__return, "post visit should never return null"); \
	return __return;

// Extra consistency checks, compiled in only when PEDANTIC_PASS_ASSERT is defined. Otherwise they
// expand to nothing, so their arguments are not compiled at all — a typo in them only shows up in
// pedantic builds.
#ifdef PEDANTIC_PASS_ASSERT
#define __pedantic_pass_assert(...) assert (__VA_ARGS__)
#define __pedantic_pass_assertf(...) assertf(__VA_ARGS__)
#else
#define __pedantic_pass_assert(...)
#define __pedantic_pass_assertf(...)
#endif

namespace ast {
	namespace __pass {
		/// Check if this is either a null pointer or a pointer to an empty container.
		/// The splice-queue accessors used below may return null, hence the null check.
		template<typename T>
		static inline bool empty( T * ptr ) {
			return !ptr || ptr->empty();
		}

		//------------------------------
		/// Move all queued declarations into `it`, each wrapped in a new DeclStmt, then clear the
		/// queue. If anything was moved and `mutated` is non-null, *mutated is set to true.
		template<typename it_t, template <class...> class container_t>
		static inline void take_all( it_t it, container_t<ast::ptr<ast::Decl>> * decls, bool * mutated = nullptr ) {
			if(empty(decls)) return;

			std::transform(decls->begin(), decls->end(), it, [](const ast::Decl * decl) -> auto {
					return new DeclStmt( decl );
				});
			decls->clear();
			if(mutated) *mutated = true;
		}

		/// Move all queued statements into `it`, then clear the queue. If anything was moved and
		/// `mutated` is non-null, *mutated is set to true.
		template<typename it_t, template <class...> class container_t>
		static inline void take_all( it_t it, container_t<ast::ptr<ast::Stmt>> * decls, bool * mutated = nullptr ) {
			if(empty(decls)) return;

			std::move(decls->begin(), decls->end(), it);
			decls->clear();
			if(mutated) *mutated = true;
		}

		//------------------------------
		/// Check if should be skipped, different for pointers and containers:
		/// a null pointer or an empty container has nothing to visit.
		template<typename node_t>
		bool skip( const ast::ptr<node_t> & val) {
			return !val;
		}

		template< template <class...> class container_t, typename node_t >
		bool skip( const container_t<ast::ptr< node_t >> & val ) {
			return val.empty();
		}

		//------------------------------
		/// Get the value to visit, different for pointers and containers: a pointer member yields
		/// its raw pointer, anything else (a container) is returned as-is. Callers pass the literal
		/// 0, so the `int` overload wins whenever it is viable and the `long` one is the fallback.
		template<typename node_t>
		auto get( const ast::ptr<node_t> & val, int ) -> decltype(val.get()) {
			return val.get();
		}

		template<typename node_t>
		const node_t & get( const node_t & val, long) {
			return val;
		}


		//------------------------------
		/// Check if value was mutated, different for pointers and containers.
		/// A pointer differs when the visit returned a different node.
		template<typename lhs_t, typename rhs_t>
		bool differs( const lhs_t * old_val, const rhs_t * new_val ) {
			return old_val != new_val;
		}

		/// The container overloads of call_accept return an empty container when nothing changed,
		/// so any non-empty result counts as a mutation.
		template< template <class...> class container_t, typename node_t >
		bool differs( const container_t<ast::ptr< node_t >> &, const container_t<ast::ptr< node_t >> & new_val ) {
			return !new_val.empty();
		}
	}

	/// Visit a node that is neither an expression nor a statement (those have dedicated overloads
	/// below) and return the node produced by the visit.
	template< typename pass_t >
	template< typename node_t >
	auto Pass< pass_t >::call_accept( const node_t * node ) -> decltype( node->accept(*this) ) {
		__pedantic_pass_assert( __visit_children() );
		// was `expr`, which does not exist in this overload and broke PEDANTIC_PASS_ASSERT builds
		__pedantic_pass_assert( node );

		static_assert( !std::is_base_of<ast::Expr, node_t>::value, "ERROR");
		static_assert( !std::is_base_of<ast::Stmt, node_t>::value, "ERROR");

		return node->accept( *this );
	}

	/// Visit an expression. If the pass exposes a type-substitution slot and the expression carries
	/// an environment, that environment is stored into the slot before the visit. The slot is not
	/// restored here (the statement overload below guards and restores it).
	template< typename pass_t >
	const ast::Expr * Pass< pass_t >::call_accept( const ast::Expr * expr ) {
		__pedantic_pass_assert( __visit_children() );
		__pedantic_pass_assert( expr );

		const ast::TypeSubstitution ** env_ptr = __pass::env( pass, 0);
		if ( env_ptr && expr->env ) {
			*env_ptr = expr->env;
		}

		return expr->accept( *this );
	}

	/// Visit a single statement. If the pass queued statements or declarations to be added before
	/// or after it, the result is wrapped in a new CompoundStmt holding them in order. The splice
	/// queues and the type-substitution slot are guarded for the duration of the visit.
	template< typename pass_t >
	const ast::Stmt * Pass< pass_t >::call_accept( const ast::Stmt * stmt ) {
		__pedantic_pass_assert( __visit_children() );
		__pedantic_pass_assert( stmt );

		// add a few useful symbols to the scope
		using __pass::empty;

		// get the stmts/decls that will need to be spliced in
		auto stmts_before = __pass::stmtsToAddBefore( pass, 0);
		auto stmts_after  = __pass::stmtsToAddAfter ( pass, 0);
		auto decls_before = __pass::declsToAddBefore( pass, 0);
		auto decls_after  = __pass::declsToAddAfter ( pass, 0);

		// These may be modified by subnodes but must be restored once we exit this statement.
		ValueGuardPtr< const ast::TypeSubstitution * > old_env         ( __pass::env( pass, 0) );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_before) >::type > old_stmts_before( stmts_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_after ) >::type > old_stmts_after ( stmts_after  );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_before) >::type > old_decls_before( decls_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_after ) >::type > old_decls_after ( decls_after  );

		// Now is the time to actually visit the node
		const ast::Stmt * nstmt = stmt->accept( *this );

		// If the pass doesn't want to add anything then we are done
		if( empty(stmts_before) && empty(stmts_after) && empty(decls_before) && empty(decls_after) ) {
			return nstmt;
		}

		// Make sure that it is either adding statements or declarations but not both
		// this is because otherwise the order would be awkward to predict
		assert(( empty( stmts_before ) && empty( stmts_after ))
		    || ( empty( decls_before ) && empty( decls_after )) );

		// Create a new Compound Statement to hold the new decls/stmts
		ast::CompoundStmt * compound = new ast::CompoundStmt( stmt->location );

		// Take all the declarations/statements that go before
		__pass::take_all( std::back_inserter( compound->kids ), decls_before );
		__pass::take_all( std::back_inserter( compound->kids ), stmts_before );

		// Insert the (possibly rewritten) original statement
		compound->kids.emplace_back( nstmt );

		// Take all the declarations/statements that go after
		__pass::take_all( std::back_inserter( compound->kids ), decls_after );
		__pass::take_all( std::back_inserter( compound->kids ), stmts_after );

		return compound;
	}

	/// Visit a list of statements, splicing any statements/declarations the pass queues before or
	/// after each one directly into the list. Semantic errors are collected across all statements
	/// and rethrown together. Returns an empty container when nothing changed.
	template< typename pass_t >
	template< template <class...> class container_t >
	container_t< ptr<Stmt> > Pass< pass_t >::call_accept( const container_t< ptr<Stmt> > & statements ) {
		__pedantic_pass_assert( __visit_children() );
		if( statements.empty() ) return {};

		// We are going to aggregate errors for all these statements
		SemanticErrorException errors;

		// add a few useful symbols to the scope
		using __pass::empty;

		// get the stmts/decls that will need to be spliced in
		auto stmts_before = __pass::stmtsToAddBefore( pass, 0);
		auto stmts_after  = __pass::stmtsToAddAfter ( pass, 0);
		auto decls_before = __pass::declsToAddBefore( pass, 0);
		auto decls_after  = __pass::declsToAddAfter ( pass, 0);

		// These may be modified by subnodes but must be restored once we exit this statement list.
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_before) >::type > old_stmts_before( stmts_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_after ) >::type > old_stmts_after ( stmts_after  );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_before) >::type > old_decls_before( decls_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_after ) >::type > old_decls_after ( decls_after  );

		// update pass statistics
		pass_visitor_stats.depth++;
		pass_visitor_stats.max->push(pass_visitor_stats.depth);
		pass_visitor_stats.avg->push(pass_visitor_stats.depth);

		bool mutated = false;
		container_t< ptr<Stmt> > new_kids;
		for( const Stmt * stmt : statements ) {
			try {
				__pedantic_pass_assert( stmt );
				const ast::Stmt * new_stmt = stmt->accept( *this );
				assert( new_stmt );
				if(new_stmt != stmt ) mutated = true;

				// Make sure that it is either adding statements or declarations but not both
				// this is because otherwise the order would be awkward to predict
				assert(( empty( stmts_before ) && empty( stmts_after ))
				    || ( empty( decls_before ) && empty( decls_after )) );



				// Take all the declarations/statements that go before this statement
				__pass::take_all( std::back_inserter( new_kids ), decls_before, &mutated );
				__pass::take_all( std::back_inserter( new_kids ), stmts_before, &mutated );

				// Now add the statement itself
				new_kids.emplace_back( new_stmt );

				// Take all the declarations/statements that go after this statement
				__pass::take_all( std::back_inserter( new_kids ), decls_after, &mutated );
				__pass::take_all( std::back_inserter( new_kids ), stmts_after, &mutated );
			}
			catch ( SemanticErrorException &e ) {
				errors.append( e );
			}
		}
		pass_visitor_stats.depth--;
		if ( !errors.isEmpty() ) { throw errors; }

		return mutated ? new_kids : container_t< ptr<Stmt> >();
	}

	/// Visit every node of a (non-statement) container. Each result must still be a node_t
	/// (checked with strict_dynamic_cast). Semantic errors are collected and rethrown together.
	/// Returns an empty container when nothing changed.
	template< typename pass_t >
	template< template <class...> class container_t, typename node_t >
	container_t< ast::ptr<node_t> > Pass< pass_t >::call_accept( const container_t< ast::ptr<node_t> > & container ) {
		__pedantic_pass_assert( __visit_children() );
		if( container.empty() ) return {};
		SemanticErrorException errors;

		pass_visitor_stats.depth++;
		pass_visitor_stats.max->push(pass_visitor_stats.depth);
		pass_visitor_stats.avg->push(pass_visitor_stats.depth);

		bool mutated = false;
		container_t< ast::ptr<node_t> > new_kids;
		for ( const node_t * node : container ) {
			try {
				__pedantic_pass_assert( node );
				const node_t * new_stmt = strict_dynamic_cast< const node_t * >( node->accept( *this ) );
				if(new_stmt != node ) mutated = true;

				new_kids.emplace_back( new_stmt );
			}
			catch( SemanticErrorException &e ) {
				errors.append( e );
			}
		}
		pass_visitor_stats.depth--;
		if ( ! errors.isEmpty() ) { throw errors; }

		return mutated ? new_kids : container_t< ast::ptr<node_t> >();
	}

	/// Visit the member `child` of `*parent` (skipped when null/empty). If the visit produced a
	/// different value, `parent` is replaced by the result of mutate(parent) with the new value
	/// assigned to that member.
	template< typename pass_t >
	template<typename node_t, typename parent_t, typename child_t>
	void Pass< pass_t >::maybe_accept(
		const node_t * & parent,
		child_t parent_t::*child
	) {
		static_assert( std::is_base_of<parent_t, node_t>::value, "Error deductiing member object" );

		if(__pass::skip(parent->*child)) return;
		const auto & old_val = __pass::get(parent->*child, 0);

		static_assert( !std::is_same<const ast::Node * &, decltype(old_val)>::value, "ERROR");

		auto new_val = call_accept( old_val );

		static_assert( !std::is_same<const ast::Node *, decltype(new_val)>::value || std::is_same<int, decltype(old_val)>::value, "ERROR");

		if( __pass::differs(old_val, new_val) ) {
			auto new_parent = mutate(parent);
			new_parent->*child = new_val;
			parent = new_parent;
		}
	}

}

//------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//------------------------------------------------------------------------------------------------------------------------------------------------------------------------

template< typename pass_t >
inline void ast::accept_all( std::list< ast::ptr<ast::Decl> > & decls, ast::Pass< pass_t > & visitor ) {
	// Errors raised by individual declarations are gathered here and rethrown once at the end.
	SemanticErrorException errors;

	using __pass::empty;

	// Queues the pass may fill while visiting a declaration (either may be null).
	auto pending_before = __pass::declsToAddBefore( visitor.pass, 0);
	auto pending_after  = __pass::declsToAddAfter ( visitor.pass, 0);

	// record nesting depth in the pass statistics
	pass_visitor_stats.depth++;
	pass_visitor_stats.max->push(pass_visitor_stats.depth);
	pass_visitor_stats.avg->push(pass_visitor_stats.depth);

	auto it = decls.begin();
	while ( true ) {
		// Whatever was queued "after" the previous declaration lands in front of the current position.
		if ( !empty( pending_after ) ) { decls.splice( it, *pending_after ); }

		if ( it == decls.end() ) break;

		try {
			ast::ptr<ast::Decl> & current = *it;
			assert( current );
			current = current->accept( visitor );
		} catch( SemanticErrorException &e ) {
			errors.append( e );
		}

		// Whatever was queued "before" the current declaration is inserted just ahead of it.
		if ( !empty( pending_before ) ) { decls.splice( it, *pending_before ); }

		++it;
	}

	pass_visitor_stats.depth--;
	if ( !errors.isEmpty() ) { throw errors; }
}

// A NOTE ON THE ORDER OF TRAVERSAL
//
// Types and typedefs have their base types visited before they are added to the type table.  This is ok, since there is
// no such thing as a recursive type or typedef.
//
//             typedef struct { T *x; } T; // never allowed
//
// For structs/unions, recursion is possible, so the decl is first added as an incomplete (forward) declaration, then the
// members are traversed, and finally the complete type is added (assuming the type is completed by this particular
// declaration).
//
//             struct T { struct T *x; }; // allowed
//
// It is important to add the complete type to the symbol table *after* the members/base has been traversed, since that
// traversal may modify the definition of the type and these modifications should be visible when the symbol table is
// queried later in this pass.

//--------------------------------------------------------------------------
// ObjectDecl
template< typename pass_t >
const ast::DeclWithType * ast::Pass< pass_t >::visit( const ast::ObjectDecl * node ) {
	VISIT_START( node );

	// Children in order: type (inside its own indexer scope), initializer, bitfield width, attributes.
	if ( __visit_children() ) {
		{
			guard_indexer type_scope { *this };
			maybe_accept( node, &ObjectDecl::type );
		}
		maybe_accept( node, &ObjectDecl::init          );
		maybe_accept( node, &ObjectDecl::bitfieldWidth );
		maybe_accept( node, &ObjectDecl::attributes    );
	}

	// The object is registered only once its own children have been traversed.
	__pass::indexer::addId( pass, 0, node );

	VISIT_END( DeclWithType, node );
}

//--------------------------------------------------------------------------
// FunctionDecl
template< typename pass_t >
const ast::DeclWithType * ast::Pass< pass_t >::visit( const ast::FunctionDecl * node ) {
	VISIT_START( node );

	// The function itself is registered before any of its children are traversed.
	__pass::indexer::addId( pass, 0, node );

	VISIT(maybe_accept( node, &FunctionDecl::withExprs );)
	{
		// with clause introduces a level of scope (for the with expression members).
		// with clause exprs are added to the indexer before parameters so that parameters
		// shadow with exprs and not the other way around.
		guard_indexer guard { *this };
		__pass::indexer::addWith( pass, 0, node->withExprs, node );
		{
			guard_indexer guard { *this };
			// implicit add __func__ identifier as specified in the C manual 6.4.2.2
			// One declaration is shared by every function. It is heap-allocated and permanently
			// referenced by a static ast::ptr instead of being a static ast::ObjectDecl: a node with
			// static storage must never be released through an ast::ptr that the indexer may take
			// and later drop.
			// NOTE(review): assumes ast::ptr reference-counts nodes — confirm against AST/Node.hpp.
			// The location is that of the first function visited, since the static is initialized once.
			static ast::ptr< ast::ObjectDecl > func{ new ast::ObjectDecl(
				node->location, "__func__",
				new ast::ArrayType(
					new ast::BasicType( ast::BasicType::Char, ast::CV::Qualifiers( ast::CV::Const ) ),
					nullptr, true, false
				)
			) };
			__pass::indexer::addId( pass, 0, func.get() );
			VISIT(
				maybe_accept( node, &FunctionDecl::type );
				// function body needs to have the same scope as parameters - CompoundStmt will not enter
				// a new scope if inFunction is true
				ValueGuard< bool > oldInFunction( inFunction );
				inFunction = true;
				maybe_accept( node, &FunctionDecl::statements );
				maybe_accept( node, &FunctionDecl::attributes );
			)
		}
	}

	VISIT_END( DeclWithType, node );
}

//--------------------------------------------------------------------------
// StructDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::StructDecl * node ) {
	VISIT_START( node );

	// Register a forward declaration first so that the members may refer to the struct itself
	// (see A NOTE ON THE ORDER OF TRAVERSAL).
	__pass::indexer::addStructFwd( pass, 0, node );

	if ( __visit_children() ) {
		// parameters and members are traversed inside a nested indexer scope
		guard_indexer member_scope { * this };
		maybe_accept( node, &StructDecl::parameters );
		maybe_accept( node, &StructDecl::members    );
	}

	// Register the complete struct; this replaces the forward declaration added above.
	__pass::indexer::addStruct( pass, 0, node );

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// UnionDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::UnionDecl * node ) {
	VISIT_START( node );

	// A forward declaration goes in first so the members can name the union itself.
	__pass::indexer::addUnionFwd( pass, 0, node );

	if ( __visit_children() ) {
		// parameters and members are traversed inside a nested indexer scope
		guard_indexer member_scope { * this };
		maybe_accept( node, &UnionDecl::parameters );
		maybe_accept( node, &UnionDecl::members    );
	}

	// Now register the complete union.
	__pass::indexer::addUnion( pass, 0, node );

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// EnumDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::EnumDecl * node ) {
	VISIT_START( node );

	// The enum is registered before its children are traversed.
	__pass::indexer::addEnum( pass, 0, node );

	if ( __visit_children() ) {
		// No nested indexer scope: unlike structs, traits, and unions, enums inject their
		// members into the global scope.
		maybe_accept( node, &EnumDecl::parameters );
		maybe_accept( node, &EnumDecl::members    );
	}

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// TraitDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::TraitDecl * node ) {
	VISIT_START( node );

	if ( __visit_children() ) {
		// parameters and members are traversed inside a nested indexer scope
		guard_indexer member_scope { *this };
		maybe_accept( node, &TraitDecl::parameters );
		maybe_accept( node, &TraitDecl::members    );
	}

	// The trait is registered only after its contents have been traversed.
	__pass::indexer::addTrait( pass, 0, node );

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// TypeDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::TypeDecl * node ) {
	VISIT_START( node );

	// parameters and base type are traversed in a nested indexer scope, before the type is added
	VISIT({
		guard_indexer guard { *this };
		maybe_accept( node, &TypeDecl::parameters );
		maybe_accept( node, &TypeDecl::base       );
	})

	// see A NOTE ON THE ORDER OF TRAVERSAL, above
	// note that assertions come after the type is added to the symtab, since they are not part of the type proper
	// and may depend on the type itself
	__pass::indexer::addType( pass, 0, node );

	VISIT(
		// maybe_accept takes only the node and the member pointer (the visitor is implicitly *this);
		// the stray third argument previously made this a compile error on instantiation
		maybe_accept( node, &TypeDecl::assertions );

		{
			guard_indexer guard { *this };
			maybe_accept( node, &TypeDecl::init );
		}
	)

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// TypedefDecl
template< typename pass_t >
const ast::Decl * ast::Pass< pass_t >::visit( const ast::TypedefDecl * node ) {
	VISIT_START( node );

	// parameters and base type are traversed in a nested indexer scope, before the typedef is added
	VISIT({
		guard_indexer guard { *this };
		maybe_accept( node, &TypedefDecl::parameters );
		maybe_accept( node, &TypedefDecl::base       );
	})

	__pass::indexer::addType( pass, 0, node );

	// Assertions are children too: like TypeDecl, only visit them when the pass wants children
	// visited (call_accept asserts __visit_children() in pedantic builds).
	VISIT(
		maybe_accept( node, &TypedefDecl::assertions );
	)

	VISIT_END( Decl, node );
}

//--------------------------------------------------------------------------
// AsmDecl
template< typename pass_t >
const ast::AsmDecl * ast::Pass< pass_t >::visit( const ast::AsmDecl * node ) {
	VISIT_START( node );

	// The only child is the wrapped asm statement.
	if ( __visit_children() ) {
		maybe_accept( node, &AsmDecl::stmt );
	}

	VISIT_END( AsmDecl, node );
}

//--------------------------------------------------------------------------
// StaticAssertDecl
template< typename pass_t >
const ast::StaticAssertDecl * ast::Pass< pass_t >::visit( const ast::StaticAssertDecl * node ) {
	VISIT_START( node );

	// Children: the asserted condition, then the message.
	if ( __visit_children() ) {
		maybe_accept( node, &StaticAssertDecl::condition );
		maybe_accept( node, &StaticAssertDecl::msg       );
	}

	VISIT_END( StaticAssertDecl, node );
}

//--------------------------------------------------------------------------
// CompoundStmt
// Visit a compound statement. A new indexer scope is entered unless inFunction is set, i.e. this
// block is a function body, which shares the scope of the parameters (see the FunctionDecl visit).
// The three guards below are order-sensitive: guard1 must capture inFunction before it is reset.
template< typename pass_t >
const ast::CompoundStmt * ast::Pass< pass_t >::visit( const ast::CompoundStmt * node ) {
	VISIT_START( node );
	VISIT({
		// do not enter a new scope if inFunction is true - needs to check old state before the assignment
		// guard1: the first lambda enters the indexer scope, the second leaves it; both use the captured
		// value of inFunction. NOTE(review): presumably makeFuncGuard runs the first now and the second on
		// scope exit — confirm.
		auto guard1 = makeFuncGuard( [this, inFunction = this->inFunction]() {
			if ( ! inFunction ) __pass::indexer::enter(pass, 0);
		}, [this, inFunction = this->inFunction]() {
			if ( ! inFunction ) __pass::indexer::leave(pass, 0);
		});
		// guard2: restore inFunction when this visit finishes
		ValueGuard< bool > guard2( inFunction );
		// guard3: NOTE(review): guard_scope is defined elsewhere; presumably scopes the pass's own state — confirm
		guard_scope guard3 { *this };
		// nested compound statements are not the function body, so they get their own scope
		inFunction = false;
		maybe_accept( node, &CompoundStmt::kids );
	})
	VISIT_END( CompoundStmt, node );
}


//--------------------------------------------------------------------------
// SingleInit
template< typename pass_t >
const ast::Init * ast::Pass< pass_t >::visit( const ast::SingleInit * node ) {
	VISIT_START( node );

	// A single initializer has exactly one child: its value expression.
	if ( __visit_children() ) {
		maybe_accept( node, &SingleInit::value );
	}

	VISIT_END( Init, node );
}

//--------------------------------------------------------------------------
// ListInit
template< typename pass_t >
const ast::Init * ast::Pass< pass_t >::visit( const ast::ListInit * node ) {
	VISIT_START( node );

	// Designations are visited before the initializers they apply to.
	if ( __visit_children() ) {
		maybe_accept( node, &ListInit::designations );
		maybe_accept( node, &ListInit::initializers );
	}

	VISIT_END( Init, node );
}

//--------------------------------------------------------------------------
// ConstructorInit
template< typename pass_t >
const ast::Init * ast::Pass< pass_t >::visit( const ast::ConstructorInit * node ) {
	VISIT_START( node );

	// Children in order: constructor call, destructor call, fallback initializer.
	if ( __visit_children() ) {
		maybe_accept( node, &ConstructorInit::ctor );
		maybe_accept( node, &ConstructorInit::dtor );
		maybe_accept( node, &ConstructorInit::init );
	}

	VISIT_END( Init, node );
}

//--------------------------------------------------------------------------
// Attribute
template< typename pass_t >
const ast::Attribute * ast::Pass< pass_t >::visit( const ast::Attribute * node  )  {
	VISIT_START( node );

	// an attribute's only children are its parameter expressions
	VISIT(
		maybe_accept( node, &Attribute::parameters );
	)

	// the first argument names the node type, consistently with every other visit()
	VISIT_END( Attribute, node );
}

//--------------------------------------------------------------------------
// TypeSubstitution
template< typename pass_t >
const ast::TypeSubstitution * ast::Pass< pass_t >::visit( const ast::TypeSubstitution * node ) {
	VISIT_START( node );

	VISIT(
		// Visit every type binding, each in its own indexer scope. The substitution is only
		// replaced (through mutate) when at least one binding was rewritten.
		{
			bool mutated = false;
			std::unordered_map< std::string, ast::ptr< ast::Type > > new_map;
			for ( const auto & p : node->typeEnv ) {
				guard_indexer guard { *this };
				auto new_node = p.second->accept( *this );
				// a different node means the pass rewrote this binding
				// (this previously set `false`, so rewrites were silently dropped)
				if (new_node != p.second) mutated = true;
				new_map.insert({ p.first, new_node });
			}
			if (mutated) {
				auto new_node = mutate( node );
				new_node->typeEnv.swap( new_map );
				node = new_node;
			}
		}

		// Same for the expression bindings.
		{
			bool mutated = false;
			std::unordered_map< std::string, ast::ptr< ast::Expr > > new_map;
			for ( const auto & p : node->varEnv ) {
				guard_indexer guard { *this };
				auto new_node = p.second->accept( *this );
				if (new_node != p.second) mutated = true;
				new_map.insert({ p.first, new_node });
			}
			if (mutated) {
				auto new_node = mutate( node );
				new_node->varEnv.swap( new_map );
				node = new_node;
			}
		}
	)

	VISIT_END( TypeSubstitution, node );
}

#undef VISIT_START
#undef VISIT
#undef VISIT_END