#include "relaxed_list.hpp"

#include <array>
#include <iomanip>
#include <iostream>
#include <locale>
#include <string>
#include <thread>
#include <vector>

#include <getopt.h>
#include <unistd.h>
#include <sys/sysinfo.h>

#include "utils.hpp"

struct __attribute__((aligned(64))) Node {
	static std::atomic_size_t creates;
	static std::atomic_size_t destroys;

	_LinksFields_t<Node> _links;

	int value;

	Node() { creates++; }
	Node(int value): value(value) { creates++; }
	~Node() { destroys++; }
};

std::atomic_size_t Node::creates  = { 0 };
std::atomic_size_t Node::destroys = { 0 };

bool enable_stats = false;

template<>
thread_local relaxed_list<Node>::TLS relaxed_list<Node>::tls = {};

template<>
relaxed_list<Node>::intrusive_queue_t::stat::Dif relaxed_list<Node>::intrusive_queue_t::stat::dif = {};

// ================================================================================================
//                        UTILS
// ================================================================================================

struct local_stat_t {
	size_t in  = 0;
	size_t out = 0;
	size_t empty = 0;
	size_t crc_in  = 0;
	size_t crc_out = 0;
};

struct global_stat_t {
	std::atomic_size_t in  = { 0 };
	std::atomic_size_t out = { 0 };
	std::atomic_size_t empty = { 0 };
	std::atomic_size_t crc_in  = { 0 };
	std::atomic_size_t crc_out = { 0 };
	struct {
		struct {
			std::atomic_size_t attempt = { 0 };
			std::atomic_size_t success = { 0 };
		} push;
		struct {
			std::atomic_size_t attempt = { 0 };
			std::atomic_size_t success = { 0 };
			std::atomic_size_t mask_attempt = { 0 };
		} pop;
	} pick;
	struct {
		struct {
			std::atomic_size_t value = { 0 };
			std::atomic_size_t count = { 0 };
		} push;
		struct {
			std::atomic_size_t value = { 0 };
			std::atomic_size_t count = { 0 };
		} pop;
	} qstat;
};

void tally_stats(global_stat_t & global, local_stat_t & local) {
	global.in    += local.in;
	global.out   += local.out;
	global.empty += local.empty;

	global.crc_in  += local.crc_in;
	global.crc_out += local.crc_out;

	global.pick.push.attempt += relaxed_list<Node>::tls.pick.push.attempt;
	global.pick.push.success += relaxed_list<Node>::tls.pick.push.success;
	global.pick.pop .attempt += relaxed_list<Node>::tls.pick.pop.attempt;
	global.pick.pop .success += relaxed_list<Node>::tls.pick.pop.success;
	global.pick.pop .mask_attempt += relaxed_list<Node>::tls.pick.pop.mask_attempt;

	global.qstat.push.value += relaxed_list<Node>::tls.empty.push.value;
	global.qstat.push.count += relaxed_list<Node>::tls.empty.push.count;
	global.qstat.pop .value += relaxed_list<Node>::tls.empty.pop .value;
	global.qstat.pop .count += relaxed_list<Node>::tls.empty.pop .count;
}

void waitfor(double & duration, barrier_t & barrier, std::atomic_bool & done) {
	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
		std::cout << "\r" << std::setprecision(4) << durr.count();
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

void print_stats(double duration, unsigned nthread, global_stat_t & global) {
	assert(Node::creates == Node::destroys);
	assert(global.crc_in == global.crc_out);

	std::cout << "Done" << std::endl;

	size_t ops = global.in + global.out;
	size_t ops_sec = size_t(double(ops) / duration);
	size_t ops_thread = ops_sec / nthread;
	auto dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "ns/Op         : " << ( dur_nano / ops_thread )<< "\n";
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Total ops     : " << ops << "(" << global.in << "i, " << global.out << "o, " << global.empty << "e)\n";
	#ifndef NO_STATS
		double push_sur = (100.0 * double(global.pick.push.success) / global.pick.push.attempt);
		double pop_sur  = (100.0 * double(global.pick.pop .success) / global.pick.pop .attempt);

		std::cout << "Push Pick %   : " << push_sur << "(" << global.pick.push.success << " / " << global.pick.push.attempt << ")\n";
		std::cout << "Pop  Pick %   : " << pop_sur  << "(" << global.pick.pop .success << " / " << global.pick.pop .attempt << ")\n";
		std::cout << "Pop mask trys : " << global.pick.pop.mask_attempt << std::endl;

		double avgQ_push = double(global.qstat.push.value) / global.qstat.push.count;
		double avgQ_pop  = double(global.qstat.pop .value) / global.qstat.pop .count;
		double avgQ      = double(global.qstat.push.value + global.qstat.pop .value) / (global.qstat.push.count + global.qstat.pop .count);
		std::cout << "Push   Avg Qs : " << avgQ_push << " (" << global.qstat.push.count << "ops)\n";
		std::cout << "Pop    Avg Qs : " << avgQ_pop  << " (" << global.qstat.pop .count << "ops)\n";
		std::cout << "Global Avg Qs : " << avgQ      << " (" << (global.qstat.push.count + global.qstat.pop .count) << "ops)\n";
	#endif
}

// ================================================================================================
//                        EXPERIMENTS
// ================================================================================================

// ================================================================================================
__attribute__((noinline)) void runChurn_body(
	std::atomic<bool>& done,
	Random & rand,
	Node * my_nodes[],
	unsigned nslots,
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	while(__builtin_expect(!done.load(std::memory_order_relaxed), true)) {
		int idx = rand.next() % nslots;
		if (auto node = my_nodes[idx]) {
			local.crc_in += node->value;
			list.push(node);
			my_nodes[idx] = nullptr;
			local.in++;
		}
		else if(auto node = list.pop()) {
			local.crc_out += node->value;
			my_nodes[idx] = node;
			local.out++;
		}
		else {
			local.empty++;
		}
	}
}

void runChurn(unsigned nthread, unsigned nqueues, double duration, unsigned nnodes, const unsigned nslots) {
	std::cout << "Churn Benchmark" << std::endl;
	assert(nnodes <= nslots);
	// List being tested
	relaxed_list<Node> list = { nthread * nqueues };

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	// Prep nodes
	std::cout << "Initializing ";
	size_t npushed = 0;

	Node** all_nodes[nthread];
	for(auto & nodes : all_nodes) {
		nodes = new __attribute__((aligned(64))) Node*[nslots + 8];
		Random rand(rdtscl());
		for(unsigned i = 0; i < nnodes; i++) {
			nodes[i] = new Node(rand.next() % 100);
		}

		for(unsigned i = nnodes; i < nslots; i++) {
			nodes[i] = nullptr;
		}

		for(int i = 0; i < 10 && i < (int)nslots; i++) {
			int idx = rand.next() % nslots;
			if (auto node = nodes[idx]) {
				global.crc_in += node->value;
				list.push(node);
				npushed++;
				nodes[idx] = nullptr;
			}
		}
	}

	std::cout << nnodes << " nodes (" << nslots << " slots)" << std::endl;

	enable_stats = true;

	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		auto & my_nodes = all_nodes[i - 1];
		t = new std::thread([&done, &list, &barrier, &global, &my_nodes, nslots](unsigned tid) {
			Random rand(tid + rdtscl());

			local_stat_t local;

			// affinity(tid);

			barrier.wait(tid);

			// EXPERIMENT START

			runChurn_body(done, rand, my_nodes, nslots, local, list);

			// EXPERIMENT END

			barrier.wait(tid);

			tally_stats(global, local);

			for(unsigned i = 0; i < nslots; i++) {
				delete my_nodes[i];
			}
		}, i++);
	}

	waitfor(duration, barrier, done);

	for(auto t : threads) {
		t->join();
		delete t;
	}

	enable_stats = false;

	while(auto node = list.pop()) {
		global.crc_out += node->value;
		delete node;
	}

	for(auto nodes : all_nodes) {
		delete[] nodes;
	}

	print_stats(duration, nthread, global);
}

// ================================================================================================
__attribute__((noinline)) void runPingPong_body(
	std::atomic<bool>& done,
	Node initial_nodes[],
	unsigned nnodes,
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	Node * nodes[nnodes];
	{
		unsigned i = 0;
		for(auto & n : nodes) {
			n = &initial_nodes[i++];
		}
	}

	while(__builtin_expect(!done.load(std::memory_order_relaxed), true)) {

		for(Node * & node : nodes) {
			local.crc_in += node->value;
			list.push(node);
			local.in++;
		}

		// -----

		for(Node * & node : nodes) {
			node = list.pop();
			assert(node);
			local.crc_out += node->value;
			local.out++;
		}
	}
}

void runPingPong(unsigned nthread, unsigned nqueues, double duration, unsigned nnodes) {
	std::cout << "PingPong Benchmark" << std::endl;

	// List being tested
	relaxed_list<Node> list = { nthread * nqueues };

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	std::cout << "Initializing ";
	enable_stats = true;

	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		t = new std::thread([&done, &list, &barrier, &global, nnodes](unsigned tid) {
			Random rand(tid + rdtscl());

			Node nodes[nnodes];
			for(auto & n : nodes) {
				n.value = (int)rand.next() % 100;
			}

			local_stat_t local;

			// affinity(tid);

			barrier.wait(tid);

			// EXPERIMENT START

			runPingPong_body(done, nodes, nnodes, local, list);

			// EXPERIMENT END

			barrier.wait(tid);

			tally_stats(global, local);
		}, i++);
	}

	waitfor(duration, barrier, done);

	for(auto t : threads) {
		t->join();
		delete t;
	}

	enable_stats = false;

	print_stats(duration, nthread, global);
}

bool iequals(const std::string& a, const std::string& b)
{
    return std::equal(a.begin(), a.end(),
                      b.begin(), b.end(),
                      [](char a, char b) {
                          return std::tolower(a) == std::tolower(b);
                      });
}

int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;
	unsigned nqueues  = 4;
	unsigned nnodes   = 100;
	unsigned nslots   = 100;

	enum {
		Churn,
		PingPong,
		NONE
	} benchmark = NONE;

	std::cout.imbue(std::locale(""));

	for(;;) {
		static struct option options[] = {
			{"duration",  required_argument, 0, 'd'},
			{"nthreads",  required_argument, 0, 't'},
			{"nqueues",   required_argument, 0, 'q'},
			{"benchmark", required_argument, 0, 'b'},
			{0, 0, 0, 0}
		};

		int idx = 0;
		int opt = getopt_long(argc, argv, "d:t:q:b:", options, &idx);

		std::string arg = optarg ? optarg : "";
		size_t len = 0;
		switch(opt) {
			// Exit Case
			case -1:
				/* paranoid */ assert(optind <= argc);
				switch(benchmark) {
				case NONE:
					std::cerr << "Must specify a benchmark" << std::endl;
					goto usage;
				case PingPong:
					nnodes = 1;
					nslots = 1;
					switch(argc - optind) {
					case 0: break;
					case 1:
						try {
							arg = optarg = argv[optind];
							nnodes = stoul(optarg, &len);
							if(len != arg.size()) { throw std::invalid_argument(""); }
						} catch(std::invalid_argument &) {
							std::cerr << "Number of nodes must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						break;
					default:
						std::cerr << "'PingPong' benchmark doesn't accept more than 2 extra arguments" << std::endl;
						goto usage;
					}
					break;
				case Churn:
					nnodes = 100;
					nslots = 100;
					switch(argc - optind) {
					case 0: break;
					case 1:
						try {
							arg = optarg = argv[optind];
							nnodes = stoul(optarg, &len);
							if(len != arg.size()) { throw std::invalid_argument(""); }
							nslots = nnodes;
						} catch(std::invalid_argument &) {
							std::cerr << "Number of nodes must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						break;
					case 2:
						try {
							arg = optarg = argv[optind];
							nnodes = stoul(optarg, &len);
							if(len != arg.size()) { throw std::invalid_argument(""); }
						} catch(std::invalid_argument &) {
							std::cerr << "Number of nodes must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						try {
							arg = optarg = argv[optind + 1];
							nslots = stoul(optarg, &len);
							if(len != arg.size()) { throw std::invalid_argument(""); }
						} catch(std::invalid_argument &) {
							std::cerr << "Number of slots must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						break;
					default:
						std::cerr << "'Churn' benchmark doesn't accept more than 2 extra arguments" << std::endl;
						goto usage;
					}
					break;
				}
				goto run;
			// Benchmarks
			case 'b':
				if(benchmark != NONE) {
					std::cerr << "Only when benchmark can be run" << std::endl;
					goto usage;
				}
				if(iequals(arg, "churn")) {
					benchmark = Churn;
					break;
				}
				if(iequals(arg, "pingpong")) {
					benchmark = PingPong;
					break;
				}
				std::cerr << "Unkown benchmark " << arg << std::endl;
				goto usage;
			// Numeric Arguments
			case 'd':
				try {
					duration = stod(optarg, &len);
					if(len != arg.size()) { throw std::invalid_argument(""); }
				} catch(std::invalid_argument &) {
					std::cerr << "Duration must be a valid double, was " << arg << std::endl;
					goto usage;
				}
				break;
			case 't':
				try {
					nthreads = stoul(optarg, &len);
					if(len != arg.size()) { throw std::invalid_argument(""); }
				} catch(std::invalid_argument &) {
					std::cerr << "Number of threads must be a positive integer, was " << arg << std::endl;
					goto usage;
				}
				break;
			case 'q':
				try {
					nqueues = stoul(optarg, &len);
					if(len != arg.size()) { throw std::invalid_argument(""); }
				} catch(std::invalid_argument &) {
					std::cerr << "Number of queues must be a positive integer, was " << arg << std::endl;
					goto usage;
				}
				break;
			// Other cases
			default: /* ? */
				std::cerr << opt << std::endl;
			usage:
				std::cerr << "Usage: " << argv[0] << ": [options] -b churn [NNODES] [NSLOTS = NNODES]" << std::endl;
				std::cerr << "  or:  " << argv[0] << ": [options] -b pingpong [NNODES]" << std::endl;
				std::cerr << std::endl;
				std::cerr << "  -d, --duration=DURATION  Duration of the experiment, in seconds" << std::endl;
				std::cerr << "  -t, --nthreads=NTHREADS  Number of kernel threads" << std::endl;
				std::cerr << "  -q, --nqueues=NQUEUES    Number of queues per threads" << std::endl;
				std::exit(1);
		}
	}
	run:

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads (" << (nthreads * nqueues) << " queues) for " << duration << " seconds" << std::endl;
	switch(benchmark) {
		case Churn:
			runChurn(nthreads, nqueues, duration, nnodes, nslots);
			break;
		case PingPong:
			runPingPong(nthreads, nqueues, duration, nnodes);
			break;
		default:
			abort();
	}
	return 0;
}

const char * __my_progname = "Relaxed List";