//
// Cforall Version 1.0.0 Copyright (C) 2019 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Pass.impl.hpp --
//
// Author           : Thierry Delisle
// Created On       : Thu May 09 15::37::05 2019
// Last Modified By :
// Last Modified On :
// Update Count     :
//

#pragma once
// IWYU pragma: private, include "Pass.hpp"

namespace ast {
template<typename pass_type>
class Pass;

namespace __pass {
	typedef std::function<void( void * )> cleanup_func_t;
	typedef std::function<void( cleanup_func_t, void * )> at_cleanup_t;


	// boolean reference that may be null
	// either refers to a boolean value or is null and returns true
	class bool_ref {
	public:
		bool_ref() = default;
		~bool_ref() = default;

		operator bool() { return m_ref ? *m_ref : true; }
		bool operator=( bool val ) { assert(m_ref); return *m_ref = val; }

	private:

		friend class visit_children_guard;

		bool * set( bool * val ) {
			bool * prev = m_ref;
			m_ref = val;
			return prev;
		}

		bool * m_ref = nullptr;
	};

	// Implementation of the guard value
	// Created inside the visit scope
	class guard_value {
	public:
		/// Push onto the cleanup
		guard_value( at_cleanup_t * at_cleanup ) {
			if( at_cleanup ) {
				*at_cleanup = [this]( cleanup_func_t && func, void* val ) {
					push( std::move( func ), val );
				};
			}
		}

		~guard_value() {
			while( !cleanups.empty() ) {
				auto& cleanup = cleanups.top();
				cleanup.func( cleanup.val );
				cleanups.pop();
			}
		}

		void push( cleanup_func_t && func, void* val ) {
			cleanups.emplace( std::move(func), val );
		}

	private:
		struct cleanup_t {
			cleanup_func_t func;
			void * val;

			cleanup_t( cleanup_func_t&& func, void * val ) : func(func), val(val) {}
		};

		std::stack< cleanup_t > cleanups;
	};

	// Guard structure implementation for whether or not children should be visited
	class visit_children_guard {
	public:

		visit_children_guard( bool_ref * ref )
			: m_val ( true )
			, m_prev( ref ? ref->set( &m_val ) : nullptr )
			, m_ref ( ref )
		{}

		~visit_children_guard() {
			if( m_ref ) {
				m_ref->set( m_prev );
			}
		}

		operator bool() { return m_val; }

	private:
		bool       m_val;
		bool     * m_prev;
		bool_ref * m_ref;
	};

	/// "Short hand" to check if this is a valid previsit function
	/// Mostly used to make the static_assert look (and print) prettier
	template<typename pass_t, typename node_t>
	struct is_valid_previsit {
		using ret_t = decltype( ((pass_t*)nullptr)->previsit( (const node_t *)nullptr ) );

		static constexpr bool value = std::is_void< ret_t >::value ||
			std::is_base_of<const node_t, typename std::remove_pointer<ret_t>::type >::value;
	};

	/// Used by previsit implementation
	/// We need to reassign the result to 'node', unless the function
	/// returns void, then we just leave 'node' unchanged
	template<bool is_void>
	struct __assign;

	template<>
	struct __assign<true> {
		template<typename pass_t, typename node_t>
		static inline void result( pass_t & pass, const node_t * & node ) {
			pass.previsit( node );
		}
	};

	template<>
	struct __assign<false> {
		template<typename pass_t, typename node_t>
		static inline void result( pass_t & pass, const node_t * & node ) {
			node = pass.previsit( node );
			assertf(node, "Previsit must not return NULL");
		}
	};

	/// Used by postvisit implementation
	/// We need to return the result unless the function
	/// returns void, then we just return the original node
	template<bool is_void>
	struct __return;

	template<>
	struct __return<true> {
		template<typename pass_t, typename node_t>
		static inline const node_t * result( pass_t & pass, const node_t * & node ) {
			pass.postvisit( node );
			return node;
		}
	};

	template<>
	struct __return<false> {
		template<typename pass_t, typename node_t>
		static inline auto result( pass_t & pass, const node_t * & node ) {
			return pass.postvisit( node );
		}
	};

	//-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
	// Deep magic (a.k.a template meta programming) to make the templated visitor work
	// Basically the goal is to make 2 previsit
	// 1 - Use when a pass implements a valid previsit. This uses overloading which means the any overload of
	//     'pass.previsit( node )' that compiles will be used for that node for that type
	//     This requires that this option only compile for passes that actually define an appropriate visit.
	//     SFINAE will make sure the compilation errors in this function don't halt the build.
	//     See http://en.cppreference.com/w/cpp/language/sfinae for details on SFINAE
	// 2 - Since the first implementation might not be specilizable, the second implementation exists and does nothing.
	//     This is needed only to eliminate the need for passes to specify any kind of handlers.
	//     The second implementation only works because it has a lower priority. This is due to the bogus last parameter.
	//     The second implementation takes a long while the first takes an int. Since the caller always passes an literal 0
	//     the first implementation takes priority in regards to overloading.
	//-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
	// PreVisit : may mutate the pointer passed in if the node is mutated in the previsit call
	template<typename pass_t, typename node_t>
	static inline auto previsit( pass_t & pass, const node_t * & node, int ) -> decltype( pass.previsit( node ), void() ) {
		static_assert(
			is_valid_previsit<pass_t, node_t>::value,
			"Previsit may not change the type of the node. It must return its paremeter or void."
		);

		__assign<
			std::is_void<
				decltype( pass.previsit( node ) )
			>::value
		>::result( pass, node );
	}

	template<typename pass_t, typename node_t>
	static inline auto previsit( pass_t &, const node_t *, long ) {}

	// PostVisit : never mutates the passed pointer but may return a different node
	template<typename pass_t, typename node_t>
	static inline auto postvisit( pass_t & pass, const node_t * node, int ) ->
		decltype( pass.postvisit( node ), node->accept( *(Visitor*)nullptr ) )
	{
		return __return<
			std::is_void<
				decltype( pass.postvisit( node ) )
			>::value
		>::result( pass, node );
	}

	template<typename pass_t, typename node_t>
	static inline const node_t * postvisit( pass_t &, const node_t * node, long ) { return node; }

	//-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
	// Deep magic (a.k.a template meta programming) continued
	// To make the templated visitor be more expressive, we allow 'accessories' : classes/structs the implementation can inherit
	// from in order to get extra functionallity for example
	// class ErrorChecker : WithShortCircuiting { ... };
	// Pass<ErrorChecker> checker;
	// this would define a pass that uses the templated visitor with the additionnal feature that it has short circuiting
	// Note that in all cases the accessories are not required but guarantee the requirements of the feature is matched
	//-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
	// For several accessories, the feature is enabled by detecting that a specific field is present
	// Use a macro the encapsulate the logic of detecting a particular field
	// The type is not strictly enforced but does match the accessory
	#define FIELD_PTR( name, default_type ) \
	template< typename pass_t > \
	static inline auto name( pass_t & pass, int ) -> decltype( &pass.name ) { return &pass.name; } \
	\
	template< typename pass_t > \
	static inline default_type * name( pass_t &, long ) { return nullptr; }

	// List of fields and their expected types
	FIELD_PTR( env, const ast::TypeSubstitution * )
	FIELD_PTR( stmtsToAddBefore, std::list< ast::ptr< ast::Stmt > > )
	FIELD_PTR( stmtsToAddAfter , std::list< ast::ptr< ast::Stmt > > )
	FIELD_PTR( declsToAddBefore, std::list< ast::ptr< ast::Decl > > )
	FIELD_PTR( declsToAddAfter , std::list< ast::ptr< ast::Decl > > )
	FIELD_PTR( visit_children, __pass::bool_ref )
	FIELD_PTR( at_cleanup, __pass::at_cleanup_t )
	FIELD_PTR( visitor, ast::Pass<pass_t> * const )

	// Remove the macro to make sure we don't clash
	#undef FIELD_PTR

	// Another feature of the templated visitor is that it calls beginScope()/endScope() for compound statement.
	// All passes which have such functions are assumed desire this behaviour
	// detect it using the same strategy
	namespace scope {
		template<typename pass_t>
		static inline auto enter( pass_t & pass, int ) -> decltype( pass.beginScope(), void() ) {
			pass.beginScope();
		}

		template<typename pass_t>
		static inline void enter( pass_t &, long ) {}

		template<typename pass_t>
		static inline auto leave( pass_t & pass, int ) -> decltype( pass.endScope(), void() ) {
			pass.endScope();
		}

		template<typename pass_t>
		static inline void leave( pass_t &, long ) {}
	};

	// Finally certain pass desire an up to date indexer automatically
	// detect the presence of a member name indexer and call all the members appropriately
	namespace indexer {
		// Some simple scoping rules
		template<typename pass_t>
		static inline auto enter( pass_t & pass, int ) -> decltype( pass.indexer.enterScope(), void() ) {
			pass.indexer.enterScope();
		}

		template<typename pass_t>
		static inline auto enter( pass_t &, long ) {}

		template<typename pass_t>
		static inline auto leave( pass_t & pass, int ) -> decltype( pass.indexer.leaveScope(), void() ) {
			pass.indexer.leaveScope();
		}

		template<typename pass_t>
		static inline auto leave( pass_t &, long ) {}

		// The indexer has 2 kind of functions mostly, 1 argument and 2 arguments
		// Create macro to condense these common patterns
		#define INDEXER_FUNC1( func, type ) \
		template<typename pass_t> \
		static inline auto func( pass_t & pass, int, type arg ) -> decltype( pass.indexer.func( arg ), void() ) {\
			pass.indexer.func( arg ); \
		} \
		\
		template<typename pass_t> \
		static inline void func( pass_t &, long, type ) {}

		#define INDEXER_FUNC2( func, type1, type2 ) \
		template<typename pass_t> \
		static inline auto func( pass_t & pass, int, type1 arg1, type2 arg2 ) -> decltype( pass.indexer.func( arg1, arg2 ), void () ) {\
			pass.indexer.func( arg1, arg2 ); \
		} \
			\
		template<typename pass_t> \
		static inline void func( pass_t &, long, type1, type2 ) {}

		INDEXER_FUNC1( addId     , const DeclWithType *  );
		INDEXER_FUNC1( addType   , const NamedTypeDecl * );
		INDEXER_FUNC1( addStruct , const StructDecl *    );
		INDEXER_FUNC1( addEnum   , const EnumDecl *      );
		INDEXER_FUNC1( addUnion  , const UnionDecl *     );
		INDEXER_FUNC1( addTrait  , const TraitDecl *     );
		INDEXER_FUNC2( addWith   , const std::vector< ptr<Expr> > &, const Node * );

		// A few extra functions have more complicated behaviour, they are hand written
		template<typename pass_t>
		static inline auto addStructFwd( pass_t & pass, int, const ast::StructDecl * decl ) -> decltype( pass.indexer.addStruct( decl ), void() ) {
			ast::StructDecl * fwd = new ast::StructDecl( decl->location, decl->name );
			fwd->params = decl->params;
			pass.indexer.addStruct( fwd );
		}

		template<typename pass_t>
		static inline void addStructFwd( pass_t &, long, const ast::StructDecl * ) {}

		template<typename pass_t>
		static inline auto addUnionFwd( pass_t & pass, int, const ast::UnionDecl * decl ) -> decltype( pass.indexer.addUnion( decl ), void() ) {
			UnionDecl * fwd = new UnionDecl( decl->location, decl->name );
			fwd->params = decl->params;
			pass.indexer.addUnion( fwd );
		}

		template<typename pass_t>
		static inline void addUnionFwd( pass_t &, long, const ast::UnionDecl * ) {}

		template<typename pass_t>
		static inline auto addStruct( pass_t & pass, int, const std::string & str ) -> decltype( pass.indexer.addStruct( str ), void() ) {
			if ( ! pass.indexer.lookupStruct( str ) ) {
				pass.indexer.addStruct( str );
			}
		}

		template<typename pass_t>
		static inline void addStruct( pass_t &, long, const std::string & ) {}

		template<typename pass_t>
		static inline auto addUnion( pass_t & pass, int, const std::string & str ) -> decltype( pass.indexer.addUnion( str ), void() ) {
			if ( ! pass.indexer.lookupUnion( str ) ) {
				pass.indexer.addUnion( str );
			}
		}

		template<typename pass_t>
		static inline void addUnion( pass_t &, long, const std::string & ) {}

		#undef INDEXER_FUNC1
		#undef INDEXER_FUNC2
	};
};
};