#pragma once
#define LIST_VARIANT work_stealing

#include <cmath>
#include <iomanip>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>

#include "assert.hpp"
#include "utils.hpp"
#include "links.hpp"
#include "links2.hpp"
#include "snzi.hpp"

// #include <x86intrin.h>

// NOTE(review): a using-directive at header scope is visible to every file that
// includes this header; the code below already spells out std:: explicitly.
using namespace std;

// Not referenced anywhere else in this file.
static const long long lim = 2000;
// Number of sub-queues per thread: used to size the queue array and to compute
// per-thread queue indices.
static const unsigned nqueues = 2;

/// A single timestamp slot, over-aligned to 128 bytes so that each slot in an
/// array occupies its own 128-byte region (presumably to keep slots written by
/// different threads on separate cache lines).
struct alignas(128) timestamp_t {
	/// Last recorded timestamp; zero until first written.
	volatile unsigned long long val{0};
};

/// One sub-queue of the work-stealing structure, aligned to 128 bytes.
///
/// Exposes the same small interface in both build flavours:
///   ts()        timestamp of the queue's front element (0 when the MPSC queue has no head)
///   lock() / try_lock() / unlock()   consumer-side lock
///   push(node)  enqueue a node
///   pop()       dequeue; the return type is whatever the underlying queue returns
template<typename node_t>
struct __attribute__((aligned(128))) localQ_t {
#ifdef NO_MPSC
	// NO_MPSC flavour: every operation forwards to the intrusive queue and the
	// lock embedded in it.
	intrusive_queue_t<node_t> list;

	auto ts() {
		return list.ts();
	}

	auto lock() {
		return list.lock.lock();
	}

	auto try_lock() {
		return list.lock.try_lock();
	}

	auto unlock() {
		return list.lock.unlock();
	}

	auto push( node_t * node ) {
		return list.push( node );
	}

	auto pop() {
		return list.pop();
	}
#else
	// Default flavour: an mpsc_queue plus a separate spinlock.
	mpsc_queue<node_t> queue{};
	spinlock_t _lock{};

	// Timestamp stored in the head node's links, or 0 when there is no head.
	auto ts() {
		auto front = queue.head();
		return front ? front->_links.ts : 0ull;
	}

	auto lock() {
		return _lock.lock();
	}

	auto try_lock() {
		return _lock.try_lock();
	}

	auto unlock() {
		return _lock.unlock();
	}

	auto push( node_t * node ) {
		return queue.push( node );
	}

	auto pop() {
		return queue.pop();
	}
#endif
};

/**
 * Work-stealing ready queue.
 *
 * Holds an array of `localQ_t` sub-queues, `nqueues` per thread. Each thread
 * that touches the structure lazily draws a ticket (see calc_preferred()) that
 * decides which `nqueues` consecutive sub-queues it treats as its own; the
 * thread holding ticket 0 owns none (`my_queue == outside`).
 *
 *  - push() stamps the node with rdtscl() and enqueues it on one of the
 *    caller's own sub-queues, or on any sub-queue if the caller owns none.
 *  - pop() tries the caller's own sub-queues, then a bounded number of random
 *    steals, then a full scan of every sub-queue.
 *
 * Per-thread counters live in `tls.stats`; unless NO_STATS is defined they can
 * be folded into `global_stats` with stats_tls_tally() and printed with
 * stats_print().
 *
 * `node_t` must have a `_links` member of type `_LinksFields_t<node_t>`.
 */
template<typename node_t>
class __attribute__((aligned(128))) work_stealing {
	static_assert(std::is_same<decltype(node_t::_links), _LinksFields_t<node_t>>::value, "Node must have a links field");

public:
	/// Human-readable name of this list variant.
	static const char * name() {
		return "Work Stealing";
	}

	/**
	 * Allocates the sub-queues and one timestamp slot per sub-queue.
	 *
	 * @param _numThreads number of threads that will own sub-queues. Note that
	 *                    the `numThreads` member stores `_numThreads * nqueues`,
	 *                    i.e. the total number of sub-queues, not threads.
	 *
	 * The second, unnamed parameter is ignored.
	 */
	work_stealing(unsigned _numThreads, unsigned)
		: numThreads(_numThreads * nqueues)
		, lists(new localQ_t<node_t>[numThreads])
		// , lists(new intrusive_queue_t<node_t>[numThreads])
		, times(new timestamp_t[numThreads])
		// , snzi( std::log2( numThreads / 2 ), 2 )

	{
		std::cout << "Constructing Work Stealer with " << numThreads << std::endl;
	}

	~work_stealing() {
		std::cout << "Destroying Work Stealer" << std::endl;
		lists.reset();
	}

	/**
	 * Enqueues `node`.
	 *
	 * The node's `_links.ts` is overwritten with rdtscl(). The target sub-queue
	 * is chosen from the thread-local counter `tls.it` (post-incremented on
	 * every attempt): any sub-queue if the caller owns none, otherwise one of
	 * the caller's own `nqueues` sub-queues. With NO_MPSC the choice is repeated
	 * until try_lock() on the chosen sub-queue succeeds.
	 */
	__attribute__((noinline, hot)) void push(node_t * node) {
		node->_links.ts = rdtscl();
		// node->_links.ts = 1;

		// GNU statement expression: evaluates to a pointer to the chosen sub-queue.
		auto & list = *({
			unsigned i;
			#ifdef NO_MPSC
				do {
			#endif
				tls.stats.push.attempt++;
				// unsigned r = tls.rng1.next();
				unsigned r = tls.it++;
				if(tls.my_queue == outside) {
					i = r % numThreads;
				} else {
					i = tls.my_queue + (r % nqueues);
				}
			#ifdef NO_MPSC
				} while(!lists[i].try_lock());
			#endif
		 	&lists[i];
		});

		list.push( node );
		#ifdef NO_MPSC
			list.unlock();
		#endif
		// tls.rng2.set_raw_state( tls.rng1.get_raw_state());
		// count++;
		tls.stats.push.success++;
	}

	/**
	 * Dequeues a node.
	 *
	 * Order of attempts: the caller's own sub-queues (skipped if it owns none),
	 * then up to 25 random steal attempts, then one pass over every sub-queue.
	 *
	 * @return the dequeued node, or nullptr if every attempt failed. A nullptr
	 *         does not prove the structure is empty: a sub-queue whose lock is
	 *         held by another thread is skipped.
	 */
	__attribute__((noinline, hot)) node_t * pop() {
		if(tls.my_queue != outside) {
			// if( tls.myfriend == outside ) {
			// 	auto r  = tls.rng1.next();
			// 	tls.myfriend = r % numThreads;
			// 	// assert(lists[(tls.it % nqueues) + tls.my_queue].ts() >= lists[((tls.it + 1) % nqueues) + tls.my_queue].ts());
			// 	tls.mytime = std::min(lists[(tls.it % nqueues) + tls.my_queue].ts(), lists[((tls.it + 1) % nqueues) + tls.my_queue].ts());
			// 	// times[tls.myfriend].val = 0;
			// 	// lists[tls.myfriend].val = 0;
			// }
			// // else if(times[tls.myfriend].val == 0) {
			// // else if(lists[tls.myfriend].val == 0) {
			// else if(times[tls.myfriend].val < tls.mytime) {
			// // else if(times[tls.myfriend].val < lists[(tls.it % nqueues) + tls.my_queue].ts()) {
			// 	node_t * n = try_pop(tls.myfriend, tls.stats.pop.help);
			// 	tls.stats.help++;
			// 	tls.myfriend = outside;
			// 	if(n) return n;
			// }
			// if( tls.myfriend == outside ) {
			// 	auto r  = tls.rng1.next();
			// 	tls.myfriend = r % numThreads;
			// 	tls.mytime = lists[((tls.it + 1) % nqueues) + tls.my_queue].ts();
			// }
			// else {
			// 	if(times[tls.myfriend].val + 1000 < tls.mytime) {
			// 		node_t * n = try_pop(tls.myfriend, tls.stats.pop.help);
			// 		tls.stats.help++;
			// 		if(n) return n;
			// 	}
			// 	tls.myfriend = outside;
			// }

			node_t * n = local();
			if(n) return n;
		}

		// try steal
		for(int i = 0; i < 25; i++) {
			node_t * n = steal();
			if(n) return n;
		}

		return search();
	}

private:
	/// Tries two of the caller's own sub-queues, pre-decrementing `tls.it`
	/// (the counter push() increments) before each attempt.
	/// Only meaningful when `tls.my_queue != outside`.
	inline node_t * local() {
		unsigned i = (--tls.it % nqueues) + tls.my_queue;
		node_t * n = try_pop(i, tls.stats.pop.local);
		if(n) return n;
		i = (--tls.it % nqueues) + tls.my_queue;
		return try_pop(i, tls.stats.pop.local);
	}

	/// One steal attempt on the sub-queue selected by `tls.rng2.prev()`.
	inline node_t * steal() {
		unsigned i = tls.rng2.prev() % numThreads;
		return try_pop(i, tls.stats.pop.steal);
	}

	/// Visits every sub-queue exactly once, starting at an offset taken from
	/// `tls.rng2.prev()`, and returns the first node obtained (or nullptr).
	inline node_t * search() {
		unsigned offset = tls.rng2.prev();
		for(unsigned i = 0; i < numThreads; i++) {
			unsigned idx = (offset + i) % numThreads;
			node_t * thrd = try_pop(idx, tls.stats.pop.search);
			if(thrd) {
				return thrd;
			}
		}

		return nullptr;
	}

private:
	/// Outcome counters for one category of pop attempt.
	struct attempt_stat_t {
		std::size_t attempt = { 0 };	// every call to try_pop
		std::size_t elock   = { 0 };	// failed: try_lock() did not succeed
		std::size_t eempty  = { 0 };	// failed: empty once the lock was held
		std::size_t espec   = { 0 };	// failed: looked empty before locking
		std::size_t success = { 0 };	// a node was popped
	};

	/**
	 * Single non-blocking attempt to pop from sub-queue `i`.
	 *
	 * @param i    index of the sub-queue; must be < numThreads.
	 * @param stat counters to update with the outcome of this attempt.
	 * @return the popped node, or nullptr if the sub-queue looked empty or its
	 *         lock could not be taken.
	 *
	 * On success the popped node's timestamp is recorded in `times[i]`.
	 */
	node_t * try_pop(unsigned i, attempt_stat_t & stat) {
		assert(i < numThreads);
		auto & list = lists[i];
		stat.attempt++;

		// If the list is empty, don't try
		if(list.ts() == 0) { stat.espec++; return nullptr; }

		// If we can't get the lock, move on
		if( !list.try_lock() ) { stat.elock++; return nullptr; }

		// Re-check under the lock: if the list is empty now, unlock and give up
		if( list.ts() == 0 ) {
			list.unlock();
			stat.eempty++;
			return nullptr;
		}

		auto node = list.pop();
		list.unlock();
		stat.success++;
		#ifdef NO_MPSC
			// times[i].val = 1;
			times[i].val = node.first->_links.ts;
			// lists[i].val = node.first->_links.ts;
			return node.first;
		#else
			times[i].val = node->_links.ts;
			return node;
		#endif
	}


public:

	/// Source of per-thread tickets consumed by calc_preferred().
	static std::atomic_uint32_t ticket;
	/// Sentinel queue index meaning "this thread owns no sub-queue".
	static const unsigned outside = 0xFFFFFFFF;

	/// Draws the next ticket and maps it to the index of the caller's first
	/// sub-queue: ticket 0 yields `outside`, ticket t > 0 yields (t - 1) * nqueues.
	/// NOTE(review): nothing here bounds the result by the number of allocated
	/// sub-queues; confirm callers never create more owning threads than the
	/// `_numThreads` passed to the constructor.
	static inline unsigned calc_preferred() {
		unsigned t = ticket++;
		if(t == 0) return outside;
		unsigned i = (t - 1) * nqueues;
		return i;
	}

	/// Per-thread state: random generators, the round-robin counter `it`, the
	/// index of the thread's first own sub-queue, and local statistics.
	static __attribute__((aligned(128))) thread_local struct TLS {
		Random     rng1 = { unsigned(std::hash<std::thread::id>{}(std::this_thread::get_id()) ^ rdtscl()) };
		Random     rng2 = { unsigned(std::hash<std::thread::id>{}(std::this_thread::get_id()) ^ rdtscl()) };
		unsigned   it   = 0;
		unsigned   my_queue = calc_preferred();
		unsigned   myfriend = outside;
		unsigned long long int mytime = 0;
		struct {
			struct {
				std::size_t attempt = { 0 };
				std::size_t success = { 0 };
			} push;
			struct {
				attempt_stat_t help;
				attempt_stat_t local;
				attempt_stat_t steal;
				attempt_stat_t search;
			} pop;
			std::size_t help = { 0 };
		} stats;
	} tls;

private:
	// Total number of sub-queues (constructor argument times nqueues).
	const unsigned numThreads;
	std::unique_ptr<localQ_t<node_t> []> lists;
	// std::unique_ptr<intrusive_queue_t<node_t> []> lists;
	// times[i] holds the timestamp of the node most recently popped from lists[i].
	std::unique_ptr<timestamp_t []> times;
	// Only referenced from commented-out code; initialised so it never holds an
	// indeterminate value.
	__attribute__((aligned(128))) std::atomic_size_t count = { 0 };

#ifndef NO_STATS
private:
	/// Process-wide totals, accumulated from each thread's `tls.stats`.
	static struct GlobalStats {
		struct {
			std::atomic_size_t attempt = { 0 };
			std::atomic_size_t success = { 0 };
		} push;
		struct {
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t elock   = { 0 };
				std::atomic_size_t eempty  = { 0 };
				std::atomic_size_t espec   = { 0 };
				std::atomic_size_t success = { 0 };
			} help;
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t elock   = { 0 };
				std::atomic_size_t eempty  = { 0 };
				std::atomic_size_t espec   = { 0 };
				std::atomic_size_t success = { 0 };
			} local;
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t elock   = { 0 };
				std::atomic_size_t eempty  = { 0 };
				std::atomic_size_t espec   = { 0 };
				std::atomic_size_t success = { 0 };
			} steal;
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t elock   = { 0 };
				std::atomic_size_t eempty  = { 0 };
				std::atomic_size_t espec   = { 0 };
				std::atomic_size_t success = { 0 };
			} search;
		} pop;
		std::atomic_size_t help = { 0 };
	} global_stats;

public:
	/// Adds the calling thread's counters to `global_stats`. The thread-local
	/// counters are not reset, so calling this twice on one thread counts its
	/// contribution twice.
	static void stats_tls_tally() {
		global_stats.push.attempt += tls.stats.push.attempt;
		global_stats.push.success += tls.stats.push.success;
		global_stats.pop.help  .attempt += tls.stats.pop.help  .attempt;
		global_stats.pop.help  .elock   += tls.stats.pop.help  .elock  ;
		global_stats.pop.help  .eempty  += tls.stats.pop.help  .eempty ;
		global_stats.pop.help  .espec   += tls.stats.pop.help  .espec  ;
		global_stats.pop.help  .success += tls.stats.pop.help  .success;
		global_stats.pop.local .attempt += tls.stats.pop.local .attempt;
		global_stats.pop.local .elock   += tls.stats.pop.local .elock  ;
		global_stats.pop.local .eempty  += tls.stats.pop.local .eempty ;
		global_stats.pop.local .espec   += tls.stats.pop.local .espec  ;
		global_stats.pop.local .success += tls.stats.pop.local .success;
		global_stats.pop.steal .attempt += tls.stats.pop.steal .attempt;
		global_stats.pop.steal .elock   += tls.stats.pop.steal .elock  ;
		global_stats.pop.steal .eempty  += tls.stats.pop.steal .eempty ;
		global_stats.pop.steal .espec   += tls.stats.pop.steal .espec  ;
		global_stats.pop.steal .success += tls.stats.pop.steal .success;
		global_stats.pop.search.attempt += tls.stats.pop.search.attempt;
		global_stats.pop.search.elock   += tls.stats.pop.search.elock  ;
		global_stats.pop.search.eempty  += tls.stats.pop.search.eempty ;
		global_stats.pop.search.espec   += tls.stats.pop.search.espec  ;
		global_stats.pop.search.success += tls.stats.pop.search.success;
		global_stats.help += tls.stats.help;
	}

	/**
	 * Prints the accumulated global statistics.
	 *
	 * @param os       stream receiving the whole report, banner included.
	 * @param duration divisor for the "Helps" rate, which is labelled "/sec".
	 *
	 * Ratios use floating-point division, so a category whose counters are
	 * still zero prints as nan/inf. `std::scientific` is left set on `os`.
	 */
	static void stats_print(std::ostream & os, double duration ) {
		// The banner goes to the same stream as the rest of the report.
		os << "----- Work Stealing Stats -----" << std::endl;

		double push_suc = (100.0 * double(global_stats.push.success) / global_stats.push.attempt);
		double push_len = double(global_stats.push.attempt     ) / global_stats.push.success;
		os << "Push   Pick : " << push_suc << " %, len " << push_len << " (" << global_stats.push.attempt      << " / " << global_stats.push.success << ")\n";

		double hlp_suc = (100.0 * double(global_stats.pop.help.success) / global_stats.pop.help.attempt);
		double hlp_len = double(global_stats.pop.help.attempt     ) / global_stats.pop.help.success;
		os << "Help        : " << hlp_suc << " %, len " << hlp_len << " (" << global_stats.pop.help.attempt      << " / " << global_stats.pop.help.success << ")\n";
		os << "Help Fail   : " << global_stats.pop.help.espec << "s, " << global_stats.pop.help.eempty << "e, " << global_stats.pop.help.elock << "l\n";

		double pop_suc = (100.0 * double(global_stats.pop.local.success) / global_stats.pop.local.attempt);
		double pop_len = double(global_stats.pop.local.attempt     ) / global_stats.pop.local.success;
		os << "Local       : " << pop_suc << " %, len " << pop_len << " (" << global_stats.pop.local.attempt      << " / " << global_stats.pop.local.success << ")\n";
		os << "Local Fail  : " << global_stats.pop.local.espec << "s, " << global_stats.pop.local.eempty << "e, " << global_stats.pop.local.elock << "l\n";

		double stl_suc = (100.0 * double(global_stats.pop.steal.success) / global_stats.pop.steal.attempt);
		double stl_len = double(global_stats.pop.steal.attempt     ) / global_stats.pop.steal.success;
		os << "Steal       : " << stl_suc << " %, len " << stl_len << " (" << global_stats.pop.steal.attempt      << " / " << global_stats.pop.steal.success << ")\n";
		os << "Steal Fail  : " << global_stats.pop.steal.espec << "s, " << global_stats.pop.steal.eempty << "e, " << global_stats.pop.steal.elock << "l\n";

		double srh_suc = (100.0 * double(global_stats.pop.search.success) / global_stats.pop.search.attempt);
		double srh_len = double(global_stats.pop.search.attempt     ) / global_stats.pop.search.success;
		os << "Search      : " << srh_suc << " %, len " << srh_len << " (" << global_stats.pop.search.attempt      << " / " << global_stats.pop.search.success << ")\n";
		os << "Search Fail : " << global_stats.pop.search.espec << "s, " << global_stats.pop.search.eempty << "e, " << global_stats.pop.search.elock << "l\n";
		os << "Helps       : " << std::setw(15) << std::scientific << global_stats.help / duration << "/sec (" << global_stats.help  << ")\n";
	}
private:
#endif
};