#include "relaxed_list.hpp"

#include <array>
#include <iostream>
#include <locale>
#include <string>
#include <thread>
#include <vector>

#include <unistd.h>
#include <sys/sysinfo.h>

#include "utils.hpp"

struct Node {
	static std::atomic_size_t creates;
	static std::atomic_size_t destroys;

	_LinksFields_t<Node> _links;

	int value;
	Node(int value): value(value) {
		creates++;
	}

	~Node() {
		destroys++;
	}
};

std::atomic_size_t Node::creates  = { 0 };
std::atomic_size_t Node::destroys = { 0 };

static const constexpr int nodes_per_threads = 128;

bool enable_stats = false;

__attribute__((aligned(64))) thread_local pick_stat local_pick;

void run(unsigned nthread, double duration) {
	// List being tested
	relaxed_list<Node> list = { nthread * 2 };

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	struct {
		std::atomic_size_t in  = { 0 };
		std::atomic_size_t out = { 0 };
		std::atomic_size_t empty = { 0 };
		std::atomic_size_t crc_in  = { 0 };
		std::atomic_size_t crc_out = { 0 };
		std::atomic_size_t pick_at = { 0 };
		std::atomic_size_t pick_su = { 0 };
	} global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	// Prep nodes
	std::cout << "Initializing" << std::endl;
	std::vector<Node *> all_nodes[nthread];
	for(auto & nodes : all_nodes) {
		Random rand(rdtscl());
		nodes.resize(nodes_per_threads);
		for(auto & node : nodes) {
			node = new Node(rand.next() % 100);
		}

		for(int i = 0; i < 10; i++) {
			int idx = rand.next() % nodes_per_threads;
			if (auto node = nodes[idx]) {
				global.crc_in += node->value;
				list.push(node);
				nodes[idx] = nullptr;
			}
		}
	}

	enable_stats = true;

	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		auto & my_nodes = all_nodes[i - 1];
		t = new std::thread([&done, &list, &barrier, &global, &my_nodes](unsigned tid) {
			Random rand(tid + rdtscl());

			size_t local_in  = 0;
			size_t local_out = 0;
			size_t local_empty = 0;
			size_t local_crc_in  = 0;
			size_t local_crc_out = 0;

			affinity(tid);

			barrier.wait(tid);

			// EXPERIMENT START

			while(__builtin_expect(!done, true)) {
				int idx = rand.next() % nodes_per_threads;
				if (auto node = my_nodes[idx]) {
					local_crc_in += node->value;
					list.push(node);
					my_nodes[idx] = nullptr;
					local_in++;
				}
				else if(auto node = list.pop2()) {
					local_crc_out += node->value;
					my_nodes[idx] = node;
					local_out++;
				}
				else {
					local_empty++;
				}
			}

			// EXPERIMENT END

			barrier.wait(tid);

			global.in    += local_in;
			global.out   += local_out;
			global.empty += local_empty;

			for(auto node : my_nodes) {
				delete node;
			}

			global.crc_in  += local_crc_in;
			global.crc_out += local_crc_out;

			global.pick_at += local_pick.attempt;
			global.pick_su += local_pick.success;
		}, i++);
	}

	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(1000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "Closing down" << std::endl;

	for(auto t : threads) {
		t->join();
		delete t;
	}

	enable_stats = false;

	while(auto node = list.pop()) {
		global.crc_out += node->value;
		delete node;
	}

	assert(Node::creates == Node::destroys);
	assert(global.crc_in == global.crc_out);

	std::cout << "Done" << std::endl;

	size_t ops = global.in + global.out;
	size_t ops_sec = size_t(double(ops) / duration);
	size_t ops_thread = ops_sec / nthread;
	auto dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "Total ops     : " << ops << "(" << global.in << "i, " << global.out << "o, " << global.empty << "e)\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "ns/Op         : " << ( dur_nano / ops_thread )<< "\n";
	std::cout << "Pick %        : " << (100.0 * double(global.pick_su) / global.pick_at) << "(" << global.pick_su << " / " << global.pick_at << ")\n";
}

void usage(char * argv[]) {
	std::cerr << argv[0] << ": [DURATION (FLOAT:SEC)] [NTHREADS]" << std::endl;;
	std::exit(1);
}

int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;

	std::cout.imbue(std::locale(""));

	switch (argc)
	{
	case 3:
		nthreads = std::stoul(argv[2]);
		[[fallthrough]];
	case 2:
		duration = std::stod(argv[1]);
		if( duration <= 0.0 ) {
			std::cerr << "Duration must be positive, was " << argv[1] << "(" << duration << ")" << std::endl;
			usage(argv);
		}
		[[fallthrough]];
	case 1:
		break;
	default:
		usage(argv);
		break;
	}

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads for " << duration << " seconds" << std::endl;
	run(nthreads, duration);

	return 0;
}

template<>
thread_local Random relaxed_list<Node>::rng_g = { int(rdtscl()) };