#pragma once
#define LIST_VARIANT relaxed_list

#define VANILLA 0
#define SNZI 1
#define BITMASK 2
#define DISCOVER 3
#define SNZM 4
#define BIAS 5

#ifndef VARIANT
#define VARIANT VANILLA
#endif

#ifndef NO_STATS
#include <iostream>
#endif

#include <cmath>
#include <memory>
#include <mutex>
#include <type_traits>

#include "assert.hpp"
#include "utils.hpp"
#include "links.hpp"
#include "snzi.hpp"
#include "snzm.hpp"

using namespace std;

struct pick_stat {
	struct {
		size_t attempt = 0;
		size_t success = 0;
		size_t local = 0;
	} push;
	struct {
		size_t attempt = 0;
		size_t success = 0;
		size_t mask_attempt = 0;
		size_t mask_reset = 0;
		size_t local = 0;
	} pop;
};

struct empty_stat {
	struct {
		size_t value = 0;
		size_t count = 0;
	} push;
	struct {
		size_t value = 0;
		size_t count = 0;
	} pop;
};

template<typename node_t>
class __attribute__((aligned(128))) relaxed_list {
	static_assert(std::is_same<decltype(node_t::_links), _LinksFields_t<node_t>>::value, "Node must have a links field");

public:
	static const char * name() {
		const char * names[] = {
			"RELAXED: VANILLA",
			"RELAXED: SNZI",
			"RELAXED: BITMASK",
			"RELAXED: SNZI + DISCOVERED MASK",
			"RELAXED: SNZI + MASK",
			"RELAXED: SNZI + LOCAL BIAS"
		};
		return names[VARIANT];
	}

	relaxed_list(unsigned numThreads, unsigned numQueues)
		: numLists(numThreads * numQueues)
	  	, lists(new intrusive_queue_t<node_t>[numLists])
		#if VARIANT == SNZI || VARIANT == BIAS
			, snzi( std::log2( numLists / (2 * numQueues) ), 2 )
		#elif VARIANT == SNZM || VARIANT == DISCOVER
			, snzm( numLists )
		#endif
	{
		assertf(7 * 8 * 8 >= numLists, "List currently only supports 448 sublists");
		std::cout << "Constructing Relaxed List with " << numLists << std::endl;
	}

	~relaxed_list() {
		std::cout << "Destroying Relaxed List" << std::endl;
		lists.reset();
	}

    	__attribute__((noinline, hot)) void push(node_t * node) {
		node->_links.ts = rdtscl();

		while(true) {
			// Pick a random list
			#if VARIANT == BIAS
			unsigned r = tls.rng.next();
			unsigned i;
			if(0 == (r & 0xF)) {
				i = r >> 4;
			} else {
				i = tls.my_queue + ((r >> 4) % 4);
				tls.pick.push.local++;
			}
			i %= numLists;
			#else
			unsigned i = tls.rng.next() % numLists;
			#endif

			#ifndef NO_STATS
				tls.pick.push.attempt++;
			#endif

			// If we can't lock it retry
			if( !lists[i].lock.try_lock() ) continue;

			#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER && VARIANT != BIAS
				__attribute__((unused)) int num = numNonEmpty;
			#endif

			// Actually push it
			if(lists[i].push(node)) {
				#if VARIANT == DISCOVER
					size_t qword = i >> 6ull;
					size_t bit   = i & 63ull;
					assert(qword == 0);
					bts(tls.mask, bit);
					snzm.arrive(i);
				#elif VARIANT == SNZI || VARIANT == BIAS
					snzi.arrive(i);
				#elif VARIANT == SNZM
					snzm.arrive(i);
				#elif VARIANT == BITMASK
					numNonEmpty++;
					size_t qword = i >> 6ull;
					size_t bit   = i & 63ull;
					assertf((list_mask[qword] & (1ul << bit)) == 0, "Before set %zu:%zu (%u), %zx & %zx", qword, bit, i, list_mask[qword].load(), (1ul << bit));
					__attribute__((unused)) bool ret = bts(list_mask[qword], bit);
					assert(!ret);
					assertf((list_mask[qword] & (1ul << bit)) != 0, "After set %zu:%zu (%u), %zx & %zx", qword, bit, i, list_mask[qword].load(), (1ul << bit));
				#else
					numNonEmpty++;
				#endif
			}
			#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER && VARIANT != BIAS
				assert(numNonEmpty <= (int)numLists);
			#endif

			// Unlock and return
			lists[i].lock.unlock();

			#ifndef NO_STATS
				tls.pick.push.success++;
				#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER && VARIANT != BIAS
					tls.empty.push.value += num;
					tls.empty.push.count += 1;
				#endif
			#endif
			return;
		}
    	}

	__attribute__((noinline, hot)) node_t * pop() {
		#if VARIANT == DISCOVER
			assert(numLists <= 64);
			while(snzm.query()) {
				tls.pick.pop.mask_attempt++;
				unsigned i, j;
				{
					// Pick first list totally randomly
					i = tls.rng.next() % numLists;

					// Pick the other according to the bitmask
					unsigned r = tls.rng.next();

					size_t mask = tls.mask.load(std::memory_order_relaxed);
					if(mask == 0) {
						tls.pick.pop.mask_reset++;
						mask = (1U << numLists) - 1;
						tls.mask.store(mask, std::memory_order_relaxed);
					}

					unsigned b = rand_bit(r, mask);

					assertf(b < 64, "%zu %u", mask, b);

					j = b;

					assert(j < numLists);
				}

				if(auto node = try_pop(i, j)) return node;
			}
		#elif VARIANT == SNZI
			while(snzi.query()) {
				// Pick two lists at random
				int i = tls.rng.next() % numLists;
				// int j = tls.rng.next() % numLists;

				if(auto node = try_pop(i, j)) return node;
			}

		#elif VARIANT == BIAS
			while(snzi.query()) {
				// Pick two lists at random
				unsigned ri = tls.rng.next();
				unsigned i;
				unsigned j = tls.rng.next();
				if(0 == (ri & 0xF)) {
					i = (ri >> 4) % numLists;
				} else {
					i = tls.my_queue + ((ri >> 4) % 4);
					j = tls.my_queue + ((j >> 4) % 4);
					tls.pick.pop.local++;
				}
				i %= numLists;
				j %= numLists;

				if(auto node = try_pop(i, j)) return node;
			}
		#elif VARIANT == SNZM
			//*
			while(snzm.query()) {
				tls.pick.pop.mask_attempt++;
				unsigned i, j;
				{
					// Pick two random number
					unsigned ri = tls.rng.next();
					unsigned rj = tls.rng.next();

					// Pick two nodes from it
					unsigned wdxi = ri & snzm.mask;
					// unsigned wdxj = rj & snzm.mask;

					// Get the masks from the nodes
					// size_t maski = snzm.masks(wdxi);
					size_t maskj = snzm.masks(wdxj);

					if(maski == 0 && maskj == 0) continue;

					#if defined(__BMI2__)
						uint64_t idxsi = _pext_u64(snzm.indexes, maski);
						// uint64_t idxsj = _pext_u64(snzm.indexes, maskj);

						auto pi = __builtin_popcountll(maski);
						// auto pj = __builtin_popcountll(maskj);

						ri = pi ? ri & ((pi >> 3) - 1) : 0;
						rj = pj ? rj & ((pj >> 3) - 1) : 0;

						unsigned bi = (idxsi >> (ri << 3)) & 0xff;
						unsigned bj = (idxsj >> (rj << 3)) & 0xff;
					#else
						unsigned bi = rand_bit(ri >> snzm.depth, maski);
						unsigned bj = rand_bit(rj >> snzm.depth, maskj);
					#endif

					i = (bi << snzm.depth) | wdxi;
					j = (bj << snzm.depth) | wdxj;

					/* paranoid */ assertf(i < numLists, "%u %u", bj, wdxi);
					/* paranoid */ assertf(j < numLists, "%u %u", bj, wdxj);
				}

				if(auto node = try_pop(i, j)) return node;
			}
			/*/
			while(snzm.query()) {
				// Pick two lists at random
				int i = tls.rng.next() % numLists;
				int j = tls.rng.next() % numLists;

				if(auto node = try_pop(i, j)) return node;
			}
			//*/
		#elif VARIANT == BITMASK
			int nnempty;
			while(0 != (nnempty = numNonEmpty)) {
				tls.pick.pop.mask_attempt++;
				unsigned i, j;
				{
					// Pick two lists at random
					unsigned num = ((numLists - 1) >> 6) + 1;

					unsigned ri = tls.rng.next();
					unsigned rj = tls.rng.next();

					unsigned wdxi = (ri >> 6u) % num;
					unsigned wdxj = (rj >> 6u) % num;

					size_t maski = list_mask[wdxi].load(std::memory_order_relaxed);
					size_t maskj = list_mask[wdxj].load(std::memory_order_relaxed);

					if(maski == 0 && maskj == 0) continue;

					unsigned bi = rand_bit(ri, maski);
					unsigned bj = rand_bit(rj, maskj);

					assertf(bi < 64, "%zu %u", maski, bi);
					assertf(bj < 64, "%zu %u", maskj, bj);

					i = bi | (wdxi << 6);
					j = bj | (wdxj << 6);

					assertf(i < numLists, "%u", wdxi << 6);
					assertf(j < numLists, "%u", wdxj << 6);
				}

				if(auto node = try_pop(i, j)) return node;
			}
		#else
			while(numNonEmpty != 0) {
				// Pick two lists at random
				int i = tls.rng.next() % numLists;
				int j = tls.rng.next() % numLists;

				if(auto node = try_pop(i, j)) return node;
			}
		#endif

		return nullptr;
    	}

private:
	node_t * try_pop(unsigned i, unsigned j) {
		#ifndef NO_STATS
			tls.pick.pop.attempt++;
		#endif

		#if VARIANT == DISCOVER
			if(lists[i].ts() > 0) bts(tls.mask, i); else btr(tls.mask, i);
			if(lists[j].ts() > 0) bts(tls.mask, j); else btr(tls.mask, j);
		#endif

		// Pick the bet list
		int w = i;
		if( __builtin_expect(lists[j].ts() != 0, true) ) {
			w = (lists[i].ts() < lists[j].ts()) ? i : j;
		}

		auto & list = lists[w];
		// If list looks empty retry
		if( list.ts() == 0 ) return nullptr;

		// If we can't get the lock retry
		if( !list.lock.try_lock() ) return nullptr;

		#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER  && VARIANT != BIAS
			__attribute__((unused)) int num = numNonEmpty;
		#endif

		// If list is empty, unlock and retry
		if( list.ts() == 0 ) {
			list.lock.unlock();
			return nullptr;
		}

		// Actually pop the list
		node_t * node;
		bool emptied;
		std::tie(node, emptied) = list.pop();
		assert(node);

		if(emptied) {
			#if VARIANT == DISCOVER
				size_t qword = w >> 6ull;
				size_t bit   = w & 63ull;
				assert(qword == 0);
				__attribute__((unused)) bool ret = btr(tls.mask, bit);
				snzm.depart(w);
			#elif VARIANT == SNZI || VARIANT == BIAS
				snzi.depart(w);
			#elif VARIANT == SNZM
				snzm.depart(w);
			#elif VARIANT == BITMASK
				numNonEmpty--;
				size_t qword = w >> 6ull;
				size_t bit   = w & 63ull;
				assert((list_mask[qword] & (1ul << bit)) != 0);
				__attribute__((unused)) bool ret = btr(list_mask[qword], bit);
				assert(ret);
				assert((list_mask[qword] & (1ul << bit)) == 0);
			#else
				numNonEmpty--;
			#endif
		}

		// Unlock and return
		list.lock.unlock();
		#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER && VARIANT != BIAS
			assert(numNonEmpty >= 0);
		#endif
		#ifndef NO_STATS
			tls.pick.pop.success++;
			#if VARIANT != SNZM && VARIANT != SNZI && VARIANT != DISCOVER && VARIANT != BIAS
				tls.empty.pop.value += num;
				tls.empty.pop.count += 1;
			#endif
		#endif
		return node;
	}


public:

	static __attribute__((aligned(128))) thread_local struct TLS {
		Random     rng = { int(rdtscl()) };
		unsigned   my_queue = (ticket++) * 4;
		pick_stat  pick;
		empty_stat empty;
		__attribute__((aligned(64))) std::atomic_size_t mask = { 0 };
	} tls;

private:
	const unsigned numLists;
    	__attribute__((aligned(64))) std::unique_ptr<intrusive_queue_t<node_t> []> lists;
private:
	#if VARIANT == SNZI || VARIANT == BIAS
		snzi_t snzi;
	#elif VARIANT == SNZM || VARIANT == DISCOVER
		snzm_t snzm;
	#else
		std::atomic_int numNonEmpty  = { 0 };  // number of non-empty lists
	#endif
	#if VARIANT == BITMASK
		std::atomic_size_t list_mask[7] = { {0}, {0}, {0}, {0}, {0}, {0}, {0} }; // which queues are empty
	#endif

public:
	static const constexpr size_t sizeof_queue = sizeof(intrusive_queue_t<node_t>);
	static std::atomic_uint32_t ticket;

#ifndef NO_STATS
	static void stats_tls_tally() {
		global_stats.pick.push.attempt += tls.pick.push.attempt;
		global_stats.pick.push.success += tls.pick.push.success;
		global_stats.pick.push.local += tls.pick.push.local;
		global_stats.pick.pop .attempt += tls.pick.pop.attempt;
		global_stats.pick.pop .success += tls.pick.pop.success;
		global_stats.pick.pop .mask_attempt += tls.pick.pop.mask_attempt;
		global_stats.pick.pop .mask_reset += tls.pick.pop.mask_reset;
		global_stats.pick.pop .local += tls.pick.pop.local;

		global_stats.qstat.push.value += tls.empty.push.value;
		global_stats.qstat.push.count += tls.empty.push.count;
		global_stats.qstat.pop .value += tls.empty.pop .value;
		global_stats.qstat.pop .count += tls.empty.pop .count;
	}

private:
	static struct GlobalStats {
		struct {
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t success = { 0 };
				std::atomic_size_t local = { 0 };
			} push;
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t success = { 0 };
				std::atomic_size_t mask_attempt = { 0 };
				std::atomic_size_t mask_reset = { 0 };
				std::atomic_size_t local = { 0 };
			} pop;
		} pick;
		struct {
			struct {
				std::atomic_size_t value = { 0 };
				std::atomic_size_t count = { 0 };
			} push;
			struct {
				std::atomic_size_t value = { 0 };
				std::atomic_size_t count = { 0 };
			} pop;
		} qstat;
	} global_stats;

public:
	static void stats_print(std::ostream & os ) {
		std::cout << "----- Relaxed List Stats -----" << std::endl;

		const auto & global = global_stats;

		double push_sur = (100.0 * double(global.pick.push.success) / global.pick.push.attempt);
		double pop_sur  = (100.0 * double(global.pick.pop .success) / global.pick.pop .attempt);
		double mpop_sur = (100.0 * double(global.pick.pop .success) / global.pick.pop .mask_attempt);
		double rpop_sur = (100.0 * double(global.pick.pop .success) / global.pick.pop .mask_reset);

		double push_len = double(global.pick.push.attempt     ) / global.pick.push.success;
		double pop_len  = double(global.pick.pop .attempt     ) / global.pick.pop .success;
		double mpop_len = double(global.pick.pop .mask_attempt) / global.pick.pop .success;
		double rpop_len = double(global.pick.pop .mask_reset  ) / global.pick.pop .success;

		os << "Push   Pick   : " << push_sur << " %, len " << push_len << " (" << global.pick.push.attempt      << " / " << global.pick.push.success << ")\n";
		os << "Pop    Pick   : " << pop_sur  << " %, len " << pop_len  << " (" << global.pick.pop .attempt      << " / " << global.pick.pop .success << ")\n";
		os << "TryPop Pick   : " << mpop_sur << " %, len " << mpop_len << " (" << global.pick.pop .mask_attempt << " / " << global.pick.pop .success << ")\n";
		os << "Pop M Reset   : " << rpop_sur << " %, len " << rpop_len << " (" << global.pick.pop .mask_reset   << " / " << global.pick.pop .success << ")\n";

		double avgQ_push = double(global.qstat.push.value) / global.qstat.push.count;
		double avgQ_pop  = double(global.qstat.pop .value) / global.qstat.pop .count;
		double avgQ      = double(global.qstat.push.value + global.qstat.pop .value) / (global.qstat.push.count + global.qstat.pop .count);
		os << "Push   Avg Qs : " << avgQ_push << " (" << global.qstat.push.count << "ops)\n";
		os << "Pop    Avg Qs : " << avgQ_pop  << " (" << global.qstat.pop .count << "ops)\n";
		os << "Global Avg Qs : " << avgQ      << " (" << (global.qstat.push.count + global.qstat.pop .count) << "ops)\n";

		os << "Local Push    : " << global.pick.push.local << "\n";
		os << "Local Pop     : " << global.pick.pop .local << "\n";
	}
#endif
};