#include "relaxed_list.hpp"

#include <array>
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <unistd.h>
#include <sys/sysinfo.h>

#include "utils.hpp"

// Element type stored in the list under test. Every instance is aligned to
// 64 bytes and carries the intrusive link fields next to an integer payload.
// The two static counters record how many nodes were constructed/destroyed.
struct alignas(64) Node {
	// Incremented once per constructor call
	static std::atomic_size_t creates;
	// Incremented once per destructor call
	static std::atomic_size_t destroys;

	// Intrusive link fields, accessed by the list implementation
	_LinksFields_t<Node> _links;

	// Payload, summed into the benchmark's in/out checksums
	int value;

	Node(int value): value{value} {
		creates.fetch_add(1);
	}

	~Node() {
		destroys.fetch_add(1);
	}
};

// Out-of-line definitions of Node's instance counters; both start at zero.
std::atomic_size_t Node::creates{0};
std::atomic_size_t Node::destroys{0};

// Number of node slots owned by each worker thread.
static constexpr int nodes_per_threads = 128;

// The slot array of one worker thread. Both members are aligned to 64 bytes,
// so the trailing `pad` byte extends the struct past the end of `array`
// (presumably to keep neighbouring threads' slots on separate cache lines).
struct NodeArray {
	alignas(64) Node * array[nodes_per_threads];
	alignas(64) char pad;
};

// Global switch: run() sets it to true just before spawning the workers and
// back to false once they have been joined. It is not read in the visible
// part of this file -- presumably consulted by relaxed_list.hpp (TODO confirm).
bool enable_stats{false};

// Tallies kept privately by each worker thread while the benchmark loop runs;
// they are folded into the shared totals once the loop has ended.
struct local_stat_t {
	size_t in{0};       // nodes pushed onto the list
	size_t out{0};      // nodes popped from the list
	size_t empty{0};    // iterations where the slot was empty and pop returned nothing
	size_t crc_in{0};   // sum of the values of all pushed nodes
	size_t crc_out{0};  // sum of the values of all popped nodes
};

__attribute__((noinline)) void run_body(
	std::atomic<bool>& done,
	Random & rand,
	Node * (&my_nodes)[128],
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	while(__builtin_expect(!done.load(std::memory_order_relaxed), true)) {
		int idx = rand.next() % nodes_per_threads;
		if (auto node = my_nodes[idx]) {
			local.crc_in += node->value;
			list.push(node);
			my_nodes[idx] = nullptr;
			local.in++;
		}
		else if(auto node = list.pop()) {
			local.crc_out += node->value;
			my_nodes[idx] = node;
			local.out++;
		}
		else {
			local.empty++;
		}
	}
}

void run(unsigned nthread, unsigned nqueues, unsigned fill, double duration) {
	// List being tested
	relaxed_list<Node> list = { nthread * nqueues };

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	struct {
		std::atomic_size_t in  = { 0 };
		std::atomic_size_t out = { 0 };
		std::atomic_size_t empty = { 0 };
		std::atomic_size_t crc_in  = { 0 };
		std::atomic_size_t crc_out = { 0 };
		struct {
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t success = { 0 };
			} push;
			struct {
				std::atomic_size_t attempt = { 0 };
				std::atomic_size_t success = { 0 };
			} pop;
		} pick;
	} global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	// Prep nodes
	std::cout << "Initializing ";
	size_t nnodes  = 0;
	size_t npushed = 0;
	NodeArray all_nodes[nthread];
	for(auto & nodes : all_nodes) {
		Random rand(rdtscl());
		for(auto & node : nodes.array) {
			auto r = rand.next() % 100;
			if(r < fill) {
				node = new Node(rand.next() % 100);
				nnodes++;
			} else {
				node = nullptr;
			}
		}

		for(int i = 0; i < 10; i++) {
			int idx = rand.next() % nodes_per_threads;
			if (auto node = nodes.array[idx]) {
				global.crc_in += node->value;
				list.push(node);
				npushed++;
				nodes.array[idx] = nullptr;
			}
		}
	}

	std::cout << nnodes << " nodes " << fill << "% (" << npushed << " pushed)" << std::endl;

	enable_stats = true;

	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		auto & my_nodes = all_nodes[i - 1].array;
		t = new std::thread([&done, &list, &barrier, &global, &my_nodes](unsigned tid) {
			Random rand(tid + rdtscl());

			local_stat_t local;

			// affinity(tid);

			barrier.wait(tid);

			// EXPERIMENT START

			run_body(done, rand, my_nodes, local, list);

			// EXPERIMENT END

			barrier.wait(tid);

			global.in    += local.in;
			global.out   += local.out;
			global.empty += local.empty;

			for(auto node : my_nodes) {
				delete node;
			}

			global.crc_in  += local.crc_in;
			global.crc_out += local.crc_out;

			global.pick.push.attempt += relaxed_list<Node>::tls.pick.push.attempt;
			global.pick.push.success += relaxed_list<Node>::tls.pick.push.success;
			global.pick.pop .attempt += relaxed_list<Node>::tls.pick.pop.attempt;
			global.pick.pop .success += relaxed_list<Node>::tls.pick.pop.success;
		}, i++);
	}

	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
		std::cout << "\r" << std::setprecision(4) << durr.count();
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;

	for(auto t : threads) {
		t->join();
		delete t;
	}

	enable_stats = false;

	while(auto node = list.pop()) {
		global.crc_out += node->value;
		delete node;
	}

	assert(Node::creates == Node::destroys);
	assert(global.crc_in == global.crc_out);

	std::cout << "Done" << std::endl;

	size_t ops = global.in + global.out;
	size_t ops_sec = size_t(double(ops) / duration);
	size_t ops_thread = ops_sec / nthread;
	auto dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "ns/Op         : " << ( dur_nano / ops_thread )<< "\n";
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Total ops     : " << ops << "(" << global.in << "i, " << global.out << "o, " << global.empty << "e)\n";
	#ifndef NO_STATS
		double push_sur = (100.0 * double(global.pick.push.success) / global.pick.push.attempt);
		double pop_sur  = (100.0 * double(global.pick.pop .success) / global.pick.pop .attempt);
		std::cout << "Push Pick %   : " << push_sur << "(" << global.pick.push.success << " / " << global.pick.push.attempt << ")\n";
		std::cout << "Pop  Pick %   : " << pop_sur  << "(" << global.pick.pop .success << " / " << global.pick.pop .attempt << ")\n";
	#endif
}

// Writes the command-line synopsis, prefixed with the program name taken from
// argv[0], to std::cerr and then terminates the process with exit status 1.
void usage(char * argv[]) {
	const char * const program = argv[0];
	std::cerr << program << ": [DURATION (FLOAT:SEC)] [NTHREADS] [NQUEUES] [FILL]" << std::endl;
	std::exit(1);
}

// Parses `text` as an unsigned integer that is at least `minimum` and fits in
// `unsigned`. Malformed, negative, out-of-range or partially numeric input is
// reported on std::cerr (using `what` as the argument's name) and the program
// terminates through usage().
static unsigned parse_unsigned_arg(const char * text, const char * what, unsigned minimum, char * argv[]) {
	try {
		const std::string str(text);
		std::size_t used = 0;
		const unsigned long parsed = std::stoul(str, &used);
		// std::stoul accepts a leading '-' and wraps the result, so reject it explicitly
		const bool negative = str.find('-') != std::string::npos;
		if( !negative && used == str.size()
		 && parsed >= minimum && parsed <= std::numeric_limits<unsigned>::max() ) {
			return static_cast<unsigned>(parsed);
		}
	}
	catch(const std::logic_error &) {
		// std::invalid_argument or std::out_of_range: reported below
	}

	std::cerr << what << " must be an integer >= " << minimum << ", was " << text << std::endl;
	usage(argv);
	return minimum; // not reached: usage() exits
}

// Parses `text` as a strictly positive number of seconds. Anything else
// (not a number, trailing characters, zero, negative, NaN) is reported on
// std::cerr and the program terminates through usage().
static double parse_duration_arg(const char * text, char * argv[]) {
	double seconds = 0.0;
	bool complete = false;
	try {
		std::size_t used = 0;
		seconds = std::stod(text, &used);
		complete = (text[used] == '\0');
	}
	catch(const std::logic_error &) {
		// std::invalid_argument or std::out_of_range: reported below
	}

	// Negated comparison on purpose: it also rejects NaN, for which
	// `seconds <= 0.0` is false and the benchmark would never terminate.
	if( !complete || !(seconds > 0.0) ) {
		std::cerr << "Duration must be positive, was " << text << "(" << seconds << ")" << std::endl;
		usage(argv);
	}
	return seconds;
}

// Entry point: [DURATION (FLOAT:SEC)] [NTHREADS] [NQUEUES] [FILL].
// All arguments are optional and positional; missing ones keep the defaults
// below. Invalid arguments print a message plus the usage line and exit with 1.
int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;
	unsigned nqueues  = 2;
	unsigned fill     = 100;

	try {
		std::cout.imbue(std::locale(""));
	}
	catch(const std::runtime_error &) {
		// The environment names a locale that is not available:
		// keep the stream's current locale rather than aborting.
	}

	// Each case parses its own argument and falls through to the ones before it
	switch (argc)
	{
	case 5:
		fill = parse_unsigned_arg(argv[4], "Fill", 0, argv);
		[[fallthrough]];
	case 4:
		// At least one queue per thread and one thread: run() divides by the thread count
		nqueues = parse_unsigned_arg(argv[3], "Number of queues", 1, argv);
		[[fallthrough]];
	case 3:
		nthreads = parse_unsigned_arg(argv[2], "Number of threads", 1, argv);
		[[fallthrough]];
	case 2:
		duration = parse_duration_arg(argv[1], argv);
		[[fallthrough]];
	case 1:
		break;
	default:
		usage(argv);
		break;
	}

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads (" << (nthreads * nqueues) << " queues) for " << duration << " seconds" << std::endl;
	run(nthreads, nqueues, fill, duration);

	return 0;
}

// Explicit specialization defining the thread-local static member `tls` of
// relaxed_list<Node>. Each worker thread in run() reads its own copy's
// pick.push / pick.pop counters and adds them to the shared totals.
template<>
thread_local relaxed_list<Node>::TLS relaxed_list<Node>::tls = {};

// Explicit specialization defining the static member `dif` of the nested
// queue statistics type, initialized from an empty brace list. It is not
// referenced in the visible part of this file.
template<>
relaxed_list<Node>::intrusive_queue_t::stat::Dif relaxed_list<Node>::intrusive_queue_t::stat::dif = {};

// Human-readable program name. Not used in the visible part of this file --
// presumably picked up by a project header or runtime (TODO confirm).
const char * __my_progname{"Relaxed List"};