#pragma once

#include <iostream>
#include <memory>
#include <mutex>
#include <type_traits>

#include "assert.hpp"
#include "utils.hpp"

using namespace std;

extern bool enable_stats;


struct pick_stat {
	size_t attempt = 0;
	size_t success = 0;
};

extern __attribute__((aligned(64))) thread_local pick_stat local_pick;

template<typename node_t>
struct _LinksFields_t {
	node_t * prev = nullptr;
	node_t * next = nullptr;
	unsigned long long ts = 0;
};

struct spinlock_t {
	std::atomic_bool ll = { false };

	inline void lock() {
		while( __builtin_expect(ll.exchange(true),false) ) {
			while(ll.load(std::memory_order_relaxed))
				asm volatile("pause");
		}
	}

	inline void unlock() {
		ll.store(false, std::memory_order_release);
	}
};

template<typename node_t>
class relaxed_list {
	static_assert(std::is_same<decltype(node_t::_links), _LinksFields_t<node_t>>::value, "Node must have a links field");


public:
	relaxed_list(unsigned numLists)
		: numLists(numLists)
		, lists(new intrusive_queue_t[numLists])
	  	, numNonEmpty(0)
	{}

    	void push(node_t * node) {
		int i = rng_g.next() % numLists;
		lists[i].push(node, numNonEmpty);
    	}

	node_t * pop() {
		int i = pickRandomly(-1);
		int j = pickRandomly(i);

		if(i == -1) {
			return nullptr;
		}

		auto guard = lock(i, j);
		auto & list = best(i, j);
		return list.pop(numNonEmpty);
    	}

	node_t * pop2() {
		int i = pickRandomly(-1);
		int j = pickRandomly(i);

		if(i == -1) {
			return nullptr;
		}

		auto & list = best2(i, j);
		return list.pop2(numNonEmpty);
    	}

private:

	class intrusive_queue_t {
	public:
		typedef spinlock_t lock_t;

		friend class relaxed_list<node_t>;

	private:
		struct sentinel_t {
			_LinksFields_t<node_t> _links;
		};

		struct stat {
			size_t push = 0;
			size_t pop  = 0;
		};

		__attribute__((aligned(64))) lock_t lock;
		__attribute__((aligned(64))) bool empty;
		stat s;
		sentinel_t before;
		sentinel_t after;

		static constexpr auto fields_offset = offsetof( node_t, _links );
	public:
		intrusive_queue_t()
			: empty(true)
			, before{{ nullptr, tail() }}
			, after {{ head(), nullptr }}
		{
			assert((reinterpret_cast<uintptr_t>( head() ) + fields_offset) == reinterpret_cast<uintptr_t>(&before));
			assert((reinterpret_cast<uintptr_t>( tail() ) + fields_offset) == reinterpret_cast<uintptr_t>(&after ));
			assert(head()->_links.prev == nullptr);
			assert(head()->_links.next == tail() );
			assert(tail()->_links.next == nullptr);
			assert(tail()->_links.prev == head() );
		}

		~intrusive_queue_t() {
			std::cout << " Push: " << s.push << "\tPop: " << s.pop << "\t(this: " << this << ")" << std::endl;
		}

		node_t * head() const {
			return reinterpret_cast<node_t *>(
				reinterpret_cast<uintptr_t>( &before ) - fields_offset
			);
		}

		node_t * tail() const {
			return reinterpret_cast<node_t *>(
				reinterpret_cast<uintptr_t>( &after ) - fields_offset
			);
		}

		void push(node_t * node, volatile int & nonEmpty) {
			node_t * tail = this->tail();
			std::lock_guard<lock_t> guard(lock);
			node->_links.ts = rdtscl();

			node_t * prev = tail->_links.prev;
			// assertf(node->_links.ts >= prev->_links.ts,
			// "New node has smaller timestamp: %llu < %llu", node->_links.ts, prev->_links.ts);
			node->_links.next = tail;
			node->_links.prev = prev;
			prev->_links.next = node;
			tail->_links.prev = node;
			if(empty) {
				__atomic_fetch_add(&nonEmpty, 1, __ATOMIC_SEQ_CST);
				empty = false;
			}
			if(enable_stats) s.push++;
		}

		node_t * pop(volatile int & nonEmpty) {
			node_t * head = this->head();
			node_t * tail = this->tail();

			node_t * node = head->_links.next;
			node_t * next = node->_links.next;
			if(node == tail) return nullptr;

			head->_links.next = next;
			next->_links.prev = head;

			if(next == tail) {
				empty = true;
				__atomic_fetch_sub(&nonEmpty, 1, __ATOMIC_SEQ_CST);
			}
			if(enable_stats) s.pop++;
			return node;
		}

		node_t * pop2(volatile int & nonEmpty) {
			node_t * head = this->head();
			node_t * tail = this->tail();

			std::lock_guard<lock_t> guard(lock);
			node_t * node = head->_links.next;
			node_t * next = node->_links.next;
			if(node == tail) return nullptr;

			head->_links.next = next;
			next->_links.prev = head;

			if(next == tail) {
				empty = true;
				__atomic_fetch_sub(&nonEmpty, 1, __ATOMIC_SEQ_CST);
			}
			if(enable_stats) s.pop++;
			return node;
		}

		static intrusive_queue_t & best(intrusive_queue_t & lhs, intrusive_queue_t & rhs) {
			bool lhs_empty = lhs.empty;
			bool rhs_empty = rhs.empty;

			if(lhs_empty && rhs_empty) return lhs;
			if(!lhs_empty && rhs_empty) return lhs;
			if(lhs_empty && !rhs_empty) return rhs;
			node_t * lhs_head = lhs.head()->_links.next;
			node_t * rhs_head = rhs.head()->_links.next;

			assert(lhs_head != lhs.tail());
			assert(rhs_head != rhs.tail());

			if(lhs_head->_links.ts < lhs_head->_links.ts) {
				return lhs;
			} else {
				return rhs;
			}
		}

		static intrusive_queue_t & best2(intrusive_queue_t & lhs, intrusive_queue_t & rhs) {
			node_t * lhs_head = lhs.head()->_links.next;
			node_t * rhs_head = rhs.head()->_links.next;

			bool lhs_empty = lhs_head != lhs.tail();
			bool rhs_empty = rhs_head != rhs.tail();
			if(lhs_empty && rhs_empty) return lhs;
			if(!lhs_empty && rhs_empty) return lhs;
			if(lhs_empty && !rhs_empty) return rhs;

			if(lhs_head->_links.ts < lhs_head->_links.ts) {
				return lhs;
			} else {
				return rhs;
			}
		}
	};


private:

	static thread_local Random rng_g;
    	__attribute__((aligned(64))) const unsigned numLists;
	std::unique_ptr<intrusive_queue_t []> lists;
	__attribute__((aligned(64))) volatile int numNonEmpty; // number of non-empty lists


private:



private:
	int pickRandomly(const int avoid) {
		int j;
		do {
			local_pick.attempt++;
			j = rng_g.next() % numLists;
			if (numNonEmpty < 1 + (avoid != -1)) return -1;
		} while (j == avoid || lists[j].empty);
		local_pick.success++;
		return j;
	}

private:

	struct queue_guard {
		intrusive_queue_t * lists;
		int i, j;

		queue_guard(intrusive_queue_t * lists, int i, int j)
			: lists(lists), i(i), j(j)
		{
			if(i >= 0) lists[i].lock.lock();
			if(j >= 0) lists[j].lock.lock();
		}

		queue_guard(const queue_guard &) = delete;
		queue_guard(queue_guard &&) = default;

		~queue_guard() {
			if(i >= 0) lists[i].lock.unlock();
			if(j >= 0) lists[j].lock.unlock();
		}
	};

	auto lock(int i, int j) {
		assert(i >= 0);
		assert(i != j);
		if(j < i) return queue_guard(lists.get(), j, i);
		return queue_guard(lists.get(), i, j);
	}

	intrusive_queue_t & best(int i, int j) {
		assert(i != -1);
		if(j == -1) return lists[i];
		return intrusive_queue_t::best(lists[i], lists[j]);
	}

	intrusive_queue_t & best2(int i, int j) {
		assert(i != -1);
		if(j == -1) return lists[i];
		return intrusive_queue_t::best2(lists[i], lists[j]);
	}
};
