#pragma once
#define LIST_VARIANT work_stealing

#include <cmath>
#include <iomanip>
#include <memory>
#include <mutex>
#include <type_traits>

#include "assert.hpp"
#include "utils.hpp"
#include "links.hpp"
#include "snzi.hpp"

using namespace std;

/// Multi-queue scheduler with one lock-protected intrusive queue per thread.
///
/// push() places a node on the queue named by node->_links.hint (choosing a
/// random queue when the hint is out of range). pop() first tries the calling
/// thread's own queue (tls.my_queue) and otherwise tries to steal from a
/// randomly chosen queue, retrying until it obtains a node or snzi.query()
/// returns false.
///
/// A SNZI (snzi_t) is kept in sync with the queues: snzi.arrive(i) is called
/// when list.push() returns true, and snzi.depart(i) when list.pop() reports
/// that the queue was emptied.
/// NOTE(review): this assumes intrusive_queue_t::push() returns true exactly
/// on the empty -> non-empty transition — confirm against its definition.
template<typename node_t>
class __attribute__((aligned(128))) work_stealing {
	static_assert(std::is_same<decltype(node_t::_links), _LinksFields_t<node_t>>::value, "Node must have a links field");

public:
	static const char * name() {
		return "Work Stealing";
	}

	/// @param _numThreads number of queues to allocate (one per thread).
	/// The second, unnamed parameter is accepted for interface compatibility
	/// and ignored.
	work_stealing(unsigned _numThreads, unsigned)
		: numThreads(_numThreads)
		, lists(new intrusive_queue_t<node_t>[numThreads])
		// Clamp the log2 argument to >= 1: with fewer than 2 threads,
		// numThreads / 2 is 0 and log2(0) is -inf, which cannot be converted
		// to an integer without undefined behaviour (presumably snzi_t takes
		// an integral depth — the double is implicitly converted).
		, snzi( std::log2( numThreads >= 2 ? numThreads / 2 : 1u ), 2 )
	{
		std::cout << "Constructing Work Stealer with " << numThreads << std::endl;
	}

	~work_stealing() {
		std::cout << "Destroying Work Stealer" << std::endl;
		lists.reset();
	}

	/// Pushes `node` onto the queue selected by node->_links.hint.
	/// The node is first stamped with rdtscl(). If the hint does not name a
	/// valid queue (hint >= numThreads), a random queue is chosen and written
	/// back into the hint so later pushes of this node reuse it.
	__attribute__((noinline, hot)) void push(node_t * node) {
		node->_links.ts = rdtscl();
		// Valid queue indices are [0, numThreads); a hint equal to numThreads
		// would index one past the end of `lists`.
		if( node->_links.hint >= numThreads ) {
			node->_links.hint = tls.rng.next() % numThreads;
			tls.stat.push.nhint++;
		}

		unsigned i = node->_links.hint;
		auto & list = lists[i];
		list.lock.lock();

		// Keep the SNZI in sync with the queue while still holding its lock.
		if(list.push( node )) {
			snzi.arrive(i);
		}

		list.lock.unlock();
	}

	/// Pops a node, preferring the calling thread's own queue and otherwise
	/// stealing from a randomly chosen queue.
	/// @return a node, or nullptr once snzi.query() returns false.
	__attribute__((noinline, hot)) node_t * pop() {
		node_t * node;
		while(true) {
			if(!snzi.query()) {
				return nullptr;
			}

			// Try the local queue first.
			{
				// my_queue is drawn from a process-wide ticket counter, so it
				// can reach numThreads or more if more threads than numThreads
				// initialise `tls`; wrap it so it always indexes inside `lists`.
				unsigned i = tls.my_queue % numThreads;
				auto & list = lists[i];
				// ts() == 0 is used as an unlocked "looks empty" check; it is
				// repeated under the lock inside try_pop().
				if( list.ts() != 0 ) {
					list.lock.lock();
					if((node = try_pop(i))) {
						tls.stat.pop.local.success++;
						break;
					}
					else {
						tls.stat.pop.local.elock++;
					}
				}
				else {
					tls.stat.pop.local.espec++;
				}
			}

			// Local attempt failed: try to steal from a random queue.
			tls.stat.pop.steal.tried++;

			unsigned i = tls.rng.next() % numThreads;
			auto & list = lists[i];
			if( list.ts() == 0 ) {
				tls.stat.pop.steal.empty++;
				continue;
			}

			// Never block on another queue's lock: skip it if contended.
			if( !list.lock.try_lock() ) {
				tls.stat.pop.steal.locked++;
				continue;
			}

			// try_pop() releases the lock on every path.
			if((node = try_pop(i))) {
				tls.stat.pop.steal.success++;
				break;
			}
		}

		// With READ defined, every READ-th pop also reads ts() of one queue,
		// cycling through the queues in order (presumably to model extra
		// reads of queue heads — confirm intent).
		#if defined(READ)
			const unsigned f = READ;
			if(0 == (tls.it % f)) {
				unsigned i = tls.it / f;
				lists[i % numThreads].ts();
			}
			// lists[tls.it].ts();
			tls.it++;
		#endif


		return node;
	}

private:
	/// Pops one node from lists[i].
	/// PRECONDITION: the caller already holds lists[i].lock; this function
	/// releases it on every path.
	/// @return the popped node, or nullptr if the queue was empty.
	node_t * try_pop(unsigned i) {
		auto & list = lists[i];

		// The unlocked emptiness check in pop() may be stale; re-check now
		// that the lock is held. If empty, unlock and let the caller retry.
		if( list.ts() == 0 ) {
			list.lock.unlock();
			return nullptr;
		}

		// Actually pop the list
		node_t * node;
		bool emptied;
		std::tie(node, emptied) = list.pop();
		assert(node);

		// This pop drained the queue: retract its SNZI arrival.
		if(emptied) {
			snzi.depart(i);
		}

		// Unlock and return
		list.lock.unlock();
		return node;
	}


public:

	/// Shared counter handing out each thread's home queue index (tls.my_queue).
	static std::atomic_uint32_t ticket;

	/// Per-thread state; static thread_local, so it is shared by every
	/// work_stealing<node_t> instance used from the same thread.
	static __attribute__((aligned(128))) thread_local struct TLS {
		Random     rng = { int(rdtscl()) };
		// Home queue index, taken from `ticket` when this thread's TLS is
		// initialised.
		unsigned   my_queue = ticket++;
		#if defined(READ)
			unsigned it = 0;
		#endif
		struct {
			struct {
				// Pushes whose hint was out of range and got a random queue.
				std::size_t nhint = { 0 };
			} push;
			struct {
				struct {
					std::size_t success = { 0 };
					// Local queue looked empty before locking.
					std::size_t espec = { 0 };
					// Local queue was found empty after locking.
					std::size_t elock = { 0 };
				} local;
				struct {
					std::size_t tried   = { 0 };
					// Victim's try_lock failed.
					std::size_t locked  = { 0 };
					// Victim looked empty before locking.
					std::size_t empty   = { 0 };
					std::size_t success = { 0 };
				} steal;
			} pop;
		} stat;
	} tls;

private:
	// Declaration order matters: the constructor's initializer list reads
	// numThreads while initialising `lists` and `snzi`.
	const unsigned numThreads;
	std::unique_ptr<intrusive_queue_t<node_t> []> lists;
	__attribute__((aligned(64))) snzi_t snzi;

#ifndef NO_STATS
private:
	static struct GlobalStats {
		struct {
			std::atomic_size_t nhint = { 0 };
		} push;
		struct {
			struct {
				std::atomic_size_t success = { 0 };
				std::atomic_size_t espec = { 0 };
				std::atomic_size_t elock = { 0 };
			} local;
			struct {
				std::atomic_size_t tried   = { 0 };
				std::atomic_size_t locked  = { 0 };
				std::atomic_size_t empty   = { 0 };
				std::atomic_size_t success = { 0 };
			} steal;
		} pop;
	} global_stats;

public:
	/// Adds the calling thread's counters into global_stats.
	/// The thread-local counters are not reset, so calling this more than
	/// once from the same thread double-counts.
	static void stats_tls_tally() {
		global_stats.push.nhint += tls.stat.push.nhint;
		global_stats.pop.local.success += tls.stat.pop.local.success;
		global_stats.pop.local.espec   += tls.stat.pop.local.espec  ;
		global_stats.pop.local.elock   += tls.stat.pop.local.elock  ;
		global_stats.pop.steal.tried   += tls.stat.pop.steal.tried  ;
		global_stats.pop.steal.locked  += tls.stat.pop.steal.locked ;
		global_stats.pop.steal.empty   += tls.stat.pop.steal.empty  ;
		global_stats.pop.steal.success += tls.stat.pop.steal.success;
	}

	/// Writes the aggregated statistics to `os`.
	static void stats_print(std::ostream & os ) {
		// The header goes to `os` like the rest of the report (it previously
		// went to std::cout regardless of the requested stream).
		os << "----- Work Stealing Stats -----" << std::endl;

		// Avoid 0/0 (NaN) when no steal was ever attempted.
		const std::size_t tried = global_stats.pop.steal.tried;
		double stealSucc = tried == 0 ? 0.0 : double(global_stats.pop.steal.success) / tried;
		os << "Push to new Q : " << std::setw(15) << global_stats.push.nhint << "\n";
		os << "Local Pop     : " << std::setw(15) << global_stats.pop.local.success << "\n";
		os << "Steal Pop     : " << std::setw(15) << global_stats.pop.steal.success << "(" << global_stats.pop.local.espec << "s, " << global_stats.pop.local.elock << "l)\n";
		os << "Steal Success : " << std::setw(15) << stealSucc << "(" << tried << " tries)\n";
		os << "Steal Fails   : " << std::setw(15) << global_stats.pop.steal.empty << "e, " << global_stats.pop.steal.locked << "l\n";
	}
private:
#endif
};