#pragma once

#ifndef NO_STATS
#include <iostream>
#endif

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <type_traits>

#include "assert.hpp"
#include "utils.hpp"

// NOTE(review): a using-directive at header scope injects all of std into every file that
// includes this header; it is left in place because code elsewhere may depend on it.
using namespace std;

/// Minimal test-and-test-and-set spin lock built on a single atomic flag.
///
/// Provides lock()/try_lock()/unlock(), so it can be used with std::lock_guard
/// and std::unique_lock. The lock is not recursive and does not track its owner.
struct spinlock_t {
	/// True while the lock is held.
	std::atomic_bool ll = { false };

	/// Busy-wait until the lock has been acquired by the caller.
	inline void lock() {
		// The exchange is expected to return false (lock was free) on the first try.
		while( __builtin_expect(ll.exchange(true),false) ) {
			// Lock is taken: spin on plain loads until it looks free, and only then
			// retry the exchange.
			while(ll.load(std::memory_order_relaxed)) {
				// Spin-wait hint for the CPU. Only emitted on architectures where the
				// instruction is known to exist, so the header builds everywhere.
				#if defined(__i386__) || defined(__x86_64__)
					asm volatile("pause");
				#elif defined(__aarch64__)
					asm volatile("yield");
				#endif
			}
		}
	}

	/// Try to take the lock without waiting.
	/// @return true if the caller now holds the lock, false if it was already held.
	inline bool try_lock() {
		return false == ll.exchange(true);
	}

	/// Release the lock. Must only be called by the current holder.
	inline void unlock() {
		ll.store(false, std::memory_order_release);
	}

	/// @return true if the lock is currently held (by anyone). The answer may be
	/// stale by the time the caller looks at it; intended for assertions.
	inline explicit operator bool() const {
		return ll.load(std::memory_order_relaxed);
	}
};


// Runtime switch for statistics: the queues below only update their push/pop balance
// while this is true (and only when NO_STATS is not defined). Declared here, defined elsewhere.
extern bool enable_stats;

/// Tally of attempted versus successful operations, kept separately for push and pop.
struct pick_stat {
	/// One attempt/success pair of counters, both starting at zero.
	struct counters {
		size_t attempt{0};
		size_t success{0};
	};

	counters push;  ///< counters for push operations
	counters pop;   ///< counters for pop operations
};

/// Intrusive link fields that a node type embeds in order to be chained into a list.
///
/// Member order (prev, next, ts) matters: the struct is brace-initialised positionally.
///
/// @tparam node_t The node type that embeds these fields.
template<typename node_t>
struct _LinksFields_t {
	node_t * prev{nullptr};       ///< previous node in the chain, or null
	node_t * next{nullptr};       ///< next node in the chain, or null
	unsigned long long ts{0};     ///< timestamp attached to the node; 0 by default
};

/// Relaxed-order concurrent queue made of `numLists` independently locked intrusive
/// sub-queues.
///
/// push() stamps the node and appends it to a randomly chosen sub-queue. pop() samples two
/// sub-queues at random and removes the head of the one whose oldest node carries the smaller
/// timestamp. Nodes are chained through their own `_links` member, so no per-element
/// allocation takes place.
///
/// @tparam node_t Element type; must have a member `_links` of type `_LinksFields_t<node_t>`.
template<typename node_t>
class __attribute__((aligned(128))) relaxed_list {
	static_assert(std::is_same<decltype(node_t::_links), _LinksFields_t<node_t>>::value, "Node must have a links field");


public:
	/// Build a container with `numLists` empty sub-queues.
	/// `numLists` must be non-zero: push() and pop() reduce a random value modulo it.
	relaxed_list(unsigned numLists)
		: numNonEmpty{0}
		, lists(new intrusive_queue_t[numLists])
		, numLists(numLists)
	{}

	/// Destroy the sub-queues (each one folds its push/pop balance into the shared
	/// statistics) and, unless NO_STATS is defined, print the average and extreme balance.
	~relaxed_list() {
		lists.reset();
		#ifndef NO_STATS
			const auto & dif = intrusive_queue_t::stat::dif;
			// Guard the average: with no queue accounted for, the division would yield a
			// non-finite double whose conversion to an integer is undefined.
			const ssize_t avg = (dif.num != 0) ? ssize_t(double(dif.value) / dif.num) : 0;
			std::cout << "Difference   : "
				<< avg << " avg\t"
				<< dif.max << "max" << std::endl;
		#endif
	}

	/// Insert `node` into one of the sub-queues.
	///
	/// The node's timestamp is overwritten with rdtscl(). Sub-queues are picked at random
	/// until one can be locked without waiting, so the call never blocks on a busy queue.
	/// @param node Node to insert; must not currently be linked in any queue.
	__attribute__((noinline, hot)) void push(node_t * node) {
		node->_links.ts = rdtscl();

		while(true) {
			// Pick a random list
			int i = tls.rng.next() % numLists;

			#ifndef NO_STATS
				tls.pick.push.attempt++;
			#endif

			// If we can't lock it retry
			if( !lists[i].lock.try_lock() ) continue;

			// Actually push it
			lists[i].push(node, numNonEmpty);
			assert(numNonEmpty <= (int)numLists);

			// Unlock and return
			lists[i].lock.unlock();

			#ifndef NO_STATS
				tls.pick.push.success++;
			#endif
			return;
		}
	}

	/// Remove and return a node, preferring older ones.
	///
	/// Two sub-queues are sampled at random and the one whose head has the smaller
	/// timestamp is tried; the attempt is repeated while any sub-queue is non-empty.
	/// @return The removed node, or nullptr once every sub-queue is observed empty.
	__attribute__((noinline, hot)) node_t * pop() {
		while(numNonEmpty != 0) {
			// Pick two lists at random
			int i = tls.rng.next() % numLists;
			int j = tls.rng.next() % numLists;

			#ifndef NO_STATS
				tls.pick.pop.attempt++;
			#endif

			// Pick the best list. These reads are done without the lock, so they are only
			// hints; the choice is re-validated under the lock below.
			const auto tsi = lists[i].ts();
			const auto tsj = lists[j].ts();
			int w = i;
			if( __builtin_expect(tsj != 0, true) ) {
				// A timestamp of 0 means "empty" and must not win the "older" comparison,
				// otherwise an empty list i would be chosen over a non-empty list j.
				w = (tsi != 0 && tsi < tsj) ? i : j;
			}

			auto & list = lists[w];
			// If list looks empty retry
			if( list.ts() == 0 ) continue;

			// If we can't get the lock retry
			if( !list.lock.try_lock() ) continue;

			// If list is empty, unlock and retry
			if( list.ts() == 0 ) {
				list.lock.unlock();
				continue;
			}

			// Actually pop the list
			auto node = list.pop(numNonEmpty);
			assert(node);

			// Unlock and return
			list.lock.unlock();
			assert(numNonEmpty >= 0);
			#ifndef NO_STATS
				tls.pick.pop.success++;
			#endif
			return node;
		}

		return nullptr;
	}

private:

	/// One lock-protected, doubly linked sub-queue with sentinel head and tail.
	///
	/// The sentinels only contain link fields; head() and tail() return node pointers
	/// offset so that their `_links` member overlays the matching sentinel. Only `_links`
	/// may ever be accessed through those pointers. The head sentinel's timestamp caches the
	/// timestamp of the oldest node, with 0 meaning "empty".
	class __attribute__((aligned(128))) intrusive_queue_t {
	public:
		typedef spinlock_t lock_t;

		friend class relaxed_list<node_t>;

		/// Push/pop balance bookkeeping.
		struct stat {
			/// Pushes minus pops recorded on this queue while enable_stats was set.
			ssize_t diff = 0;

			/// Totals accumulated from every destroyed queue.
			static struct Dif {
				ssize_t value = 0;  // sum of the per-queue balances
				size_t  num   = 0;  // number of queues accounted for
				ssize_t max   = 0;  // balance with the largest magnitude seen
			} dif;
		};

	private:
		/// Stand-in for a node that carries only the link fields.
		struct sentinel_t {
			_LinksFields_t<node_t> _links;
		};

		lock_t lock;
		sentinel_t before;
		sentinel_t after;
		stat s;

		/// Byte offset of `_links` inside a node, used to fake node pointers for the sentinels.
		static constexpr auto fields_offset = offsetof( node_t, _links );
	public:
		/// Link the two sentinels to each other and check the expected layout
		/// (128 bytes in size, 128-byte aligned).
		intrusive_queue_t()
			: before{{ nullptr, tail() }}
			, after {{ head(), nullptr }}
		{
			assert((reinterpret_cast<uintptr_t>( head() ) + fields_offset) == reinterpret_cast<uintptr_t>(&before));
			assert((reinterpret_cast<uintptr_t>( tail() ) + fields_offset) == reinterpret_cast<uintptr_t>(&after ));
			assert(head()->_links.prev == nullptr);
			assert(head()->_links.next == tail() );
			assert(tail()->_links.next == nullptr);
			assert(tail()->_links.prev == head() );
			assert(sizeof(*this) == 128);
			assert((intptr_t(this) % 128) == 0);
		}

		/// Fold this queue's balance into the shared totals (unless NO_STATS is defined).
		~intrusive_queue_t() {
			#ifndef NO_STATS
				stat::dif.value+= s.diff;
				stat::dif.num  ++;
				stat::dif.max  = std::abs(stat::dif.max) > std::abs(s.diff) ? stat::dif.max : s.diff;
			#endif
		}

		/// @return Pseudo-node whose `_links` is the `before` sentinel.
		inline node_t * head() const {
			node_t * rhead = reinterpret_cast<node_t *>(
				reinterpret_cast<uintptr_t>( &before ) - fields_offset
			);
			assert(rhead);
			return rhead;
		}

		/// @return Pseudo-node whose `_links` is the `after` sentinel.
		inline node_t * tail() const {
			node_t * rtail = reinterpret_cast<node_t *>(
				reinterpret_cast<uintptr_t>( &after ) - fields_offset
			);
			assert(rtail);
			return rtail;
		}

		/// Append `node` at the tail. The caller must hold `lock`.
		/// @param node     Node to append; its timestamp must be non-zero.
		/// @param nonEmpty Count of non-empty queues; incremented when this queue was empty.
		inline void push(node_t * node, std::atomic_int & nonEmpty) {
			assert(lock);
			assert(node->_links.ts != 0);
			node_t * tail = this->tail();

			node_t * prev = tail->_links.prev;
			// assertf(node->_links.ts >= prev->_links.ts,
			// 	"New node has smaller timestamp: %llu < %llu", node->_links.ts, prev->_links.ts);
			node->_links.next = tail;
			node->_links.prev = prev;
			prev->_links.next = node;
			tail->_links.prev = node;
			// A cached timestamp of 0 means the queue was empty: the new node is its oldest.
			if(before._links.ts == 0l) {
				nonEmpty += 1;
				before._links.ts = node->_links.ts;
			}
			#ifndef NO_STATS
				if(enable_stats) s.diff++;
			#endif
		}

		/// Detach and return the node at the head. The caller must hold `lock`.
		/// The removed node's own link fields are left untouched.
		/// @param nonEmpty Count of non-empty queues; decremented when this queue becomes empty.
		/// @return The removed node, or nullptr if the queue holds no node.
		inline node_t * pop(std::atomic_int & nonEmpty) {
			assert(lock);
			node_t * head = this->head();
			node_t * tail = this->tail();

			node_t * node = head->_links.next;
			if(node == tail) return nullptr;
			node_t * next = node->_links.next;

			head->_links.next = next;
			next->_links.prev = head;

			// Refresh the cached "oldest" timestamp, or mark the queue empty.
			if(next == tail) {
				before._links.ts = 0l;
				nonEmpty -= 1;
			}
			else {
				assert(next->_links.ts != 0);
				before._links.ts = next->_links.ts;
				assert(before._links.ts != 0);
			}
			#ifndef NO_STATS
				if(enable_stats) s.diff--;
			#endif
			return node;
		}

		/// @return Timestamp of the oldest node, or 0 when the queue is empty.
		unsigned long long ts() const {
			return before._links.ts;
		}
	};


public:

	/// Per-thread state: the random generator used to pick sub-queues and the pick counters.
	static __attribute__((aligned(128))) thread_local struct TLS {
		Random    rng = { int(rdtscl()) };
		pick_stat pick;
	} tls;

private:
	std::atomic_int numNonEmpty; // number of non-empty lists
	__attribute__((aligned(64))) std::unique_ptr<intrusive_queue_t []> lists;
	const unsigned numLists;

public:
	/// Size in bytes of one sub-queue.
	static const constexpr size_t sizeof_queue = sizeof(intrusive_queue_t);
};