#pragma once
// IWYU pragma: private, include "Pass.hpp"

#define VISIT_START( node ) \
	/* back-up the visit children */ \
	__attribute__((unused)) ast::__pass::visit_children_guard guard1( ast::__pass::visit_children(m_pass, 0) ); \
	/* setup the scope for passes that want to run code at exit */ \
	__attribute__((unused)) ast::__pass::guard_value          guard2( ast::__pass::at_cleanup    (m_pass, 0) ); \
	/* call the implementation of the previsit of this pass */ \
	__pass::previsit( m_pass, node, 0 );

#define VISIT( code ) \
	/* if this node should visit its children */ \
	if ( __visit_children() ) { \
		/* visit the children */ \
		code \
	}

#define VISIT_END( type, node ) \
	/* call the implementation of the postvisit of this pass */ \
	auto __return = __pass::postvisit< type * >( node ); \
	assertf(__return, "post visit should never return null"); \
	return __return;

#ifdef PEDANTIC_PASS_ASSERT
#define __pedantic_pass_assert (...) assert (__VAR_ARGS__)
#define __pedantic_pass_assertf(...) assertf(__VAR_ARGS__)
#else
#define __pedantic_pass_assert (...)
#define __pedantic_pass_assertf(...)
#endif

namespace ast {
	namespace __pass {
		// Check if this is either a null pointer or a pointer to an empty container
		template<typename T>
		static inline bool empty( T * ptr ) {
			return !ptr || ptr->empty();
		}

		template<typename it_t, template <class> class container_t>
		static inline void take_all( it_t it, container_t<ast::ptr<ast::Declaration>> * decls, bool * mutated = nullptr ) {
			if(empty(decls)) return;

			std::transform(decls->begin(), decls->end(), it, [](Declaration * decl) -> auto {
					return new DeclStmt( decl );
				});
			decls->clear();
			if(mutated) *mutated = true;
		}

		template<typename it_t, template <class> class container_t>
		static inline void take_all( it_t it, container_t<ast::ptr<ast::Statement>> * decls, bool * mutated = nullptr ) {
			if(empty(decls)) return;

			std::move(decls->begin(), decls->end(), it);
			decls->clear();
			if(mutated) *mutated = true;
		}

		template<typename node_t>
		bool differs( const node_t * old_val, const node_t * new_val ) {
			return old_val != new_val;
		}

		template< template <class> class container_t >
		bool differs( const container_t<ast::ptr< ast::Statement >> &, const container_t<ast::ptr< ast::Statement >> & new_val ) {
			return !new_val.empty();
		}
	}

	template<typename parent_t, typename child_t>
	template< typename pass_t >
	void Pass< pass_t >::maybe_accept(
		const parent_t * & parent,
		const typename parent_t::child_t * child
	) {
		const auto & old_val = parent->*child;
		if(!old_val) return;

		auto new_val = call_accept(old_val);

		if( __pass::differs(old_val, new_val) ) {
			auto new_parent = mutate(parent);
			new_parent->*child = new_val;
			parent = new_parent;
		}
	}

	template< typename pass_t >
	template< typename node_t >
	auto Pass< pass_t >::call_accept( const node_t * node ) {
		__pedantic_pass_assert( __visit_children() );
		__pedantic_pass_assert( expr );

		return node->accept( *this );
	}

	template< typename pass_t >
	ast::Expression * Pass< pass_t >::call_accept( const ast::Expression * expr ) {
		__pedantic_pass_assert( __visit_children() );
		__pedantic_pass_assert( expr );

		const ast::TypeSubstitution ** env_ptr = __pass::env( m_pass, 0);
		if ( env_ptr && expr->env ) {
			*env_ptr = expr->env;
		}

		return expr->accept( *this );
	}

	template< typename pass_t >
	ast::Statement * Pass< pass_t >::call_accept( const ast::Statement * stmt ) {
		__pedantic_pass_assert( __visit_children() );
		__pedantic_pass_assert( stmt );

		// add a few useful symbols to the scope
		using __pass::empty;
		using decls_t = typename std::remove_pointer< decltype(__decls_before()) >::type;
		using stmts_t = typename std::remove_pointer< decltype(__stmts_before()) >::type;

		// get the stmts/decls that will need to be spliced in
		auto stmts_before = __pass::stmtsToAddBefore( m_pass, 0);
		auto stmts_after  = __pass::stmtsToAddAfter ( m_pass, 0);
		auto decls_before = __pass::declsToAddBefore( m_pass, 0);
		auto decls_after  = __pass::declsToAddAfter ( m_pass, 0);

		// These may be modified by subnode but most be restored once we exit this statemnet.
		ValueGuardPtr< const ast::TypeSubstitution * > __old_env         ( __pass::env( m_pass, 0);  );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_before) > __old_decls_before( stmts_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_after ) > __old_decls_after ( stmts_after  );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_before) > __old_stmts_before( decls_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_after ) > __old_stmts_after ( decls_after  );

		// Now is the time to actually visit the node
		ast::Statement * nstmt = stmt->accept( *this );

		// If the pass doesn't want to add anything then we are done
		if( empty(stmts_before) && empty(stmts_after) && empty(decls_before) && empty(decls_after) ) {
			return nstmt;
		}

		// Make sure that it is either adding statements or declartions but not both
		// this is because otherwise the order would be awkward to predict
		assert(( empty( stmts_before ) && empty( stmts_after ))
		    || ( empty( decls_before ) && empty( decls_after )) );

		// Create a new Compound Statement to hold the new decls/stmts
		ast::CompoundStmt * compound = new ast::CompoundStmt( parent->*child.location );

		// Take all the declarations that go before
		__pass::take_all( std::back_inserter( compound->kids ), decls_before );
		__pass::take_all( std::back_inserter( compound->kids ), stmts_before );

		// Insert the original declaration
		compound->kids.push_back( nstmt );

		// Insert all the declarations that go before
		__pass::take_all( std::back_inserter( compound->kids ), decls_after );
		__pass::take_all( std::back_inserter( compound->kids ), stmts_after );

		return compound;
	}

	template< typename pass_t >
	template< template <class> class container_t >
	container_t< ast::ptr<ast::Statement> > Pass< pass_t >::call_accept( const container_t< ast::ptr<ast::Statement> > & statements ) {
		__pedantic_pass_assert( __visit_children() );
		if( statements.empty() ) return {};

		// We are going to aggregate errors for all these statements
		SemanticErrorException errors;

		// add a few useful symbols to the scope
		using __pass::empty;

		// get the stmts/decls that will need to be spliced in
		auto stmts_before = __pass::stmtsToAddBefore( pass, 0);
		auto stmts_after  = __pass::stmtsToAddAfter ( pass, 0);
		auto decls_before = __pass::declsToAddBefore( pass, 0);
		auto decls_after  = __pass::declsToAddAfter ( pass, 0);

		// These may be modified by subnode but most be restored once we exit this statemnet.
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_before) > __old_decls_before( stmts_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(stmts_after ) > __old_decls_after ( stmts_after  );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_before) > __old_stmts_before( decls_before );
		ValueGuardPtr< typename std::remove_pointer< decltype(decls_after ) > __old_stmts_after ( decls_after  );

		// update pass statitistics
		pass_visitor_stats.depth++;
		pass_visitor_stats.max->push(pass_visitor_stats.depth);
		pass_visitor_stats.avg->push(pass_visitor_stats.depth);

		bool mutated = false;
		container_t<ast::ptr< ast::Statement >> new_kids;
		for( const ast::Statement * stmt : statements ) {
			try {
				__pedantic_pass_assert( stmt );
				const ast::Statment * new_stmt = stmt->accept( visitor );
				assert( new_stmt );
				if(new_stmt != stmt ) mutated = true;

				// Make sure that it is either adding statements or declartions but not both
				// this is because otherwise the order would be awkward to predict
				assert(( empty( stmts_before ) && empty( stmts_after ))
				    || ( empty( decls_before ) && empty( decls_after )) );



				// Take all the statements which should have gone after, N/A for first iteration
				__pass::take_all( std::back_inserter( new_kids ), decls_before, &mutated );
				__pass::take_all( std::back_inserter( new_kids ), stmts_before, &mutated );

				// Now add the statement if there is one
				new_kids.emplace_back( new_stmt );

				// Take all the declarations that go before
				__pass::take_all( std::back_inserter( new_kids ), decls_after, &mutated );
				__pass::take_all( std::back_inserter( new_kids ), stmts_after, &mutated );
			}
			catch ( SemanticErrorException &e ) {
				errors.append( e );
			}
		}
		pass_visitor_stats.depth--;
		if ( !errors.isEmpty() ) { throw errors; }

		return mutated ? new_kids : {};
	}

	template< typename pass_t >
	template< template <class> class container_t, typename node_t >
	container_t< ast::ptr<node_t> > Pass< pass_t >::call_accept( const container_t< ast::ptr<node_t> > & container ) {
		__pedantic_pass_assert( __visit_children() );
		if( container.empty() ) return {};
		SemanticErrorException errors;

		pass_visitor_stats.depth++;
		pass_visitor_stats.max->push(pass_visitor_stats.depth);
		pass_visitor_stats.avg->push(pass_visitor_stats.depth);

		bool mutated = false;
		container_t< ast::ptr<node_t> > new_kids;
		for ( const node_t * node : container ) {
			try {
				__pedantic_pass_assert( node );
				const node_t * new_node = strict_dynamic_cast< const node_t * >( node->accept( *this ) );
				if(new_stmt != stmt ) mutated = true;

				new_kids.emplace_back( new_stmt );
			}
			catch( SemanticErrorException &e ) {
				errors.append( e );
			}
		}
		pass_visitor_stats.depth--;
		if ( ! errors.isEmpty() ) { throw errors; }

		return mutated ? new_kids : {};
	}
}

//------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//========================================================================================================================================================================
//------------------------------------------------------------------------------------------------------------------------------------------------------------------------

// A NOTE ON THE ORDER OF TRAVERSAL
//
// Types and typedefs have their base types visited before they are added to the type table.  This is ok, since there is
// no such thing as a recursive type or typedef.
//
//             typedef struct { T *x; } T; // never allowed
//
// for structs/unions, it is possible to have recursion, so the decl should be added as if it's incomplete to begin, the
// members are traversed, and then the complete type should be added (assuming the type is completed by this particular
// declaration).
//
//             struct T { struct T *x; }; // allowed
//
// It is important to add the complete type to the symbol table *after* the members/base has been traversed, since that
// traversal may modify the definition of the type and these modifications should be visible when the symbol table is
// queried later in this pass.

//--------------------------------------------------------------------------
// ObjectDecl
template< typename pass_t >
ast::DeclarationWithType * Pass< pass_t >::mutate( ast::ObjectDecl * node ) {
	VISIT_START( node );

	VISIT(
		{
			auto guard = make_indexer_guard();
			maybe_accept( node, ast::ObjectDecl::type );
		}
		maybe_accept( node, ast::ObjectDecl::init          );
		maybe_accept( node, ast::ObjectDecl::bitfieldWidth );
		maybe_accept( node, ast::ObjectDecl::attributes    );
	)

	__pass::indexer::AddId( m_pass, 0, node );

	VISIT_END( DeclarationWithType, node );
}

//--------------------------------------------------------------------------
// Attribute
template< typename pass_type >
ast::Attribute * ast::Pass< pass_type >::visit( ast::ptr<ast::Attribute> & node  )  {
	VISIT_START(node);

	VISIT(
		maybe_accept( node, ast::Attribute::parameters );
	)

	VISIT_END(ast::Attribute *, node );
}

//--------------------------------------------------------------------------
// TypeSubstitution
template< typename pass_type >
TypeSubstitution * PassVisitor< pass_type >::mutate( TypeSubstitution * node ) {
	MUTATE_START( node );

	#error this is broken

	for ( auto & p : node->typeEnv ) {
		indexerScopedMutate( p.second, *this );
	}
	for ( auto & p : node->varEnv ) {
		indexerScopedMutate( p.second, *this );
	}

	MUTATE_END( TypeSubstitution, node );
}

#undef VISIT_START
#undef VISIT
#undef VISIT_END