#include "relaxed_list.hpp"

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cctype>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <locale>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <getopt.h>
#include <unistd.h>
#include <sys/sysinfo.h>

#include "utils.hpp"

// Element pushed/popped through the relaxed list under test.
// Aligned to 64 bytes (presumably one cache line, to keep nodes used by
// different threads apart — check_cache_line_size() is called in main).
// The static counters tally every construction and destruction so that
// print_stats() can assert that no node leaked.
struct __attribute__((aligned(64))) Node {
	static std::atomic_size_t creates;
	static std::atomic_size_t destroys;

	// Link fields used by relaxed_list<Node> (type from relaxed_list.hpp).
	_LinksFields_t<Node> _links;

	// Payload: a random value in Churn/PingPong, a pop counter in Fairness.
	// Default-initialized so that a default-constructed Node never carries an
	// indeterminate value.
	int value = 0;
	// Node identifier; only assigned explicitly by the Fairness benchmark.
	int id = 0;

	Node() { creates++; }
	Node(int value): value(value) { creates++; }
	~Node() { destroys++; }
};

// Construction / destruction tallies for Node; print_stats() asserts they match.
std::atomic_size_t Node::creates  { 0 };
std::atomic_size_t Node::destroys { 0 };

// Raised by each run* benchmark just before its workers start and lowered
// after they are joined; presumably read by relaxed_list.hpp — TODO confirm.
bool enable_stats { false };

// Out-of-class definitions of relaxed_list<Node>'s static data members,
// explicitly specialized for this benchmark's Node type. The `= {}` / `= nullptr`
// initializers matter: an explicit specialization of a static data member
// without an initializer is only a declaration, not a definition.
// NOTE(review): TLS and GlobalStats are declared in relaxed_list.hpp (not visible here).

// Per-thread list state (thread_local: one instance per thread).
template<>
thread_local relaxed_list<Node>::TLS relaxed_list<Node>::tls = {};

// Static `head` pointer, initially null; its role is defined in relaxed_list.hpp.
template<>
relaxed_list<Node> * relaxed_list<Node>::head = nullptr;

// Global statistics storage; not defined when compiled with NO_STATS.
#ifndef NO_STATS
template<>
relaxed_list<Node>::GlobalStats relaxed_list<Node>::global_stats = {};
#endif

// ================================================================================================
//                        UTILS
// ================================================================================================

// Per-thread counters filled in by a worker during an experiment and merged
// into a global_stat_t by tally_stats() once it ends.
struct local_stat_t {
	size_t in      { 0 };           // successful pushes
	size_t out     { 0 };           // successful pops
	size_t empty   { 0 };           // pop attempts that returned no node
	size_t crc_in  { 0 };           // checksum of pushed payloads
	size_t crc_out { 0 };           // checksum of popped payloads
	size_t valmax  { 0 };           // largest node counter seen (set by Fairness)
	size_t valmin  { 100000000ul }; // smallest node counter seen; starts at a large sentinel
};

// Totals shared by all workers; each field mirrors one local_stat_t counter
// and is updated atomically by tally_stats().
struct global_stat_t {
	std::atomic_size_t in      { 0 };
	std::atomic_size_t out     { 0 };
	std::atomic_size_t empty   { 0 };
	std::atomic_size_t crc_in  { 0 };
	std::atomic_size_t crc_out { 0 };
	std::atomic_size_t valmax  { 0 };
	std::atomic_size_t valmin  { 100000000ul }; // large sentinel, lowered by atomic_min
};

// Atomically raise `target` to `value` when `value` is larger (CAS loop).
void atomic_max(std::atomic_size_t & target, size_t value) {
	size_t seen = target.load(std::memory_order_relaxed);
	while(value > seen && !target.compare_exchange_strong(seen, value)) {
		// A failed exchange refreshed `seen`; retry only while still larger.
	}
}

// Atomically lower `target` to `value` when `value` is smaller (CAS loop).
void atomic_min(std::atomic_size_t & target, size_t value) {
	size_t seen = target.load(std::memory_order_relaxed);
	while(value < seen && !target.compare_exchange_strong(seen, value)) {
		// A failed exchange refreshed `seen`; retry only while still smaller.
	}
}

// Merge one worker's counters into the shared totals, then let the relaxed
// list fold in its own per-thread statistics (stats_tls_tally).
void tally_stats(global_stat_t & global, local_stat_t & local) {
	// Push/pop counters.
	global.in.fetch_add(local.in);
	global.out.fetch_add(local.out);
	global.empty.fetch_add(local.empty);

	// Checksums compared by print_stats().
	global.crc_in.fetch_add(local.crc_in);
	global.crc_out.fetch_add(local.crc_out);

	// Extremes of the Fairness node counters.
	atomic_max(global.valmax, local.valmax);
	atomic_min(global.valmin, local.valmin);

	relaxed_list<Node>::stats_tls_tally();
}

void waitfor(double & duration, barrier_t & barrier, std::atomic_bool & done) {
	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
		std::cout << "\r" << std::setprecision(4) << durr.count();
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

void waitfor(double & duration, barrier_t & barrier, const std::atomic_size_t & count) {
	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		size_t c = count.load();
		if( c == 0 ) {
			break;
		}
		std::cout << "\r" << c;
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

// Print the aggregated results of a benchmark run.
//   duration : measured run time in seconds (as stored by waitfor)
//   nthread  : number of worker threads that took part
//   global   : totals accumulated by tally_stats()
// The asserts check that every Node created was destroyed and that the
// checksum of pushed payloads equals the checksum of popped payloads.
void print_stats(double duration, unsigned nthread, global_stat_t & global) {
	assert(Node::creates == Node::destroys);
	assert(global.crc_in == global.crc_out);

	std::cout << "Done" << std::endl;

	size_t ops = global.in + global.out;
	size_t ops_sec = size_t(double(ops) / duration);
	// Avoid dividing by zero: nthread may be 0, and ops_thread is 0 whenever
	// fewer than nthread operations per second were completed.
	size_t ops_thread = (nthread != 0) ? ops_sec / nthread : 0;
	auto dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	if(ops_thread != 0) {
		std::cout << "ns/Op         : " << ( dur_nano / ops_thread ) << "\n";
	} else {
		std::cout << "ns/Op         : n/a\n";
	}
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Total ops     : " << ops << "(" << global.in << "i, " << global.out << "o, " << global.empty << "e)\n";
	// valmax is only raised by the Fairness benchmark.
	if(global.valmax != 0) {
		std::cout << "Max runs      : " << global.valmax << "\n";
		std::cout << "Min runs      : " << global.valmin << "\n";
	}
	#ifndef NO_STATS
		relaxed_list<Node>::stats_print(std::cout);
	#endif
}

// Writes the fairness matrix recorded by runFairness() to `output`;
// defined near the end of this file.
void save_fairness(const int data[], int factor, unsigned nthreads, size_t columns, size_t rows, const std::string & output);

// ================================================================================================
//                        EXPERIMENTS
// ================================================================================================

// ================================================================================================
__attribute__((noinline)) void runChurn_body(
	std::atomic<bool>& done,
	Random & rand,
	Node * my_nodes[],
	unsigned nslots,
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	while(__builtin_expect(!done.load(std::memory_order_relaxed), true)) {
		int idx = rand.next() % nslots;
		if (auto node = my_nodes[idx]) {
			local.crc_in += node->value;
			list.push(node);
			my_nodes[idx] = nullptr;
			local.in++;
		}
		else if(auto node = list.pop()) {
			local.crc_out += node->value;
			my_nodes[idx] = node;
			local.out++;
		}
		else {
			local.empty++;
		}
	}
}

// Churn benchmark: each worker owns an array of `nslots` node slots whose
// first `nnodes` entries start filled. Until `duration` seconds elapse, a
// worker repeatedly picks a random slot and either pushes its node onto the
// shared list or pops a node into the empty slot (runChurn_body).
//   nthread  : number of worker threads (tids 1..nthread; main thread is 0)
//   nqueues  : queues per thread; the list is built with nthread * nqueues
//   duration : requested run time in seconds; waitfor() replaces it with
//              the measured time before it is printed
//   nnodes   : nodes allocated per worker (must not exceed nslots)
//   nslots   : slots per worker
void runChurn(unsigned nthread, unsigned nqueues, double duration, unsigned nnodes, const unsigned nslots) {
	std::cout << "Churn Benchmark" << std::endl;
	assert(nnodes <= nslots);
	// List being tested

	// Barrier for synchronization
	// (nthread workers + the main thread inside waitfor)
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	// Prep nodes
	std::cout << "Initializing ";
	// NOTE(review): npushed is incremented below but never read.
	size_t npushed = 0;
	relaxed_list<Node> list = { nthread * nqueues };
	{
		// One heap-allocated slot array per worker (VLA of pointers).
		Node** all_nodes[nthread];
		for(auto & nodes : all_nodes) {
			// NOTE(review): 8 extra slots are allocated but never indexed
			// (slots are picked modulo nslots); purpose unclear.
			nodes = new __attribute__((aligned(64))) Node*[nslots + 8];
			Random rand(rdtscl());
			for(unsigned i = 0; i < nnodes; i++) {
				nodes[i] = new Node(rand.next() % 100);
			}

			for(unsigned i = nnodes; i < nslots; i++) {
				nodes[i] = nullptr;
			}

			// Pre-push up to 10 randomly chosen nodes so the list does not
			// start empty; picks that land on an empty slot push nothing.
			for(int i = 0; i < 10 && i < (int)nslots; i++) {
				int idx = rand.next() % nslots;
				if (auto node = nodes[idx]) {
					global.crc_in += node->value;
					list.push(node);
					npushed++;
					nodes[idx] = nullptr;
				}
			}
		}

		std::cout << nnodes << " nodes (" << nslots << " slots)" << std::endl;

		enable_stats = true;

		std::thread * threads[nthread];
		unsigned i = 1;
		for(auto & t : threads) {
			auto & my_nodes = all_nodes[i - 1];
			t = new std::thread([&done, &list, &barrier, &global, &my_nodes, nslots](unsigned tid) {
				Random rand(tid + rdtscl());

				local_stat_t local;

				// affinity(tid);

				barrier.wait(tid);

				// EXPERIMENT START

				runChurn_body(done, rand, my_nodes, nslots, local, list);

				// EXPERIMENT END

				barrier.wait(tid);

				tally_stats(global, local);

				// Each worker frees whatever nodes its slots hold now, no
				// matter which worker originally allocated them.
				for(unsigned i = 0; i < nslots; i++) {
					delete my_nodes[i];
				}
			}, i++);
		}

		waitfor(duration, barrier, done);

		for(auto t : threads) {
			t->join();
			delete t;
		}

		enable_stats = false;

		// Nodes still in the list belong to no slot: drain and free them here.
		while(auto node = list.pop()) {
			global.crc_out += node->value;
			delete node;
		}

		for(auto nodes : all_nodes) {
			delete[] nodes;
		}
	}

	print_stats(duration, nthread, global);
}

// ================================================================================================
__attribute__((noinline)) void runPingPong_body(
	std::atomic<bool>& done,
	Node initial_nodes[],
	unsigned nnodes,
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	Node * nodes[nnodes];
	{
		unsigned i = 0;
		for(auto & n : nodes) {
			n = &initial_nodes[i++];
		}
	}

	while(__builtin_expect(!done.load(std::memory_order_relaxed), true)) {

		for(Node * & node : nodes) {
			local.crc_in += node->value;
			list.push(node);
			local.in++;
		}

		// -----

		for(Node * & node : nodes) {
			node = list.pop();
			assert(node);
			local.crc_out += node->value;
			local.out++;
		}
	}
}

// PingPong benchmark: every worker owns `nnodes` nodes in an automatic (stack)
// array and repeatedly pushes all of them onto the shared list, then pops the
// same number back (runPingPong_body), until `duration` seconds elapse.
//   nthread  : number of worker threads (tids 1..nthread; main thread is 0)
//   nqueues  : queues per thread; the list is built with nthread * nqueues
//   duration : requested run time in seconds; waitfor() replaces it with
//              the measured time before it is printed
//   nnodes   : nodes owned by each worker
void runPingPong(unsigned nthread, unsigned nqueues, double duration, unsigned nnodes) {
	std::cout << "PingPong Benchmark" << std::endl;


	// Barrier for synchronization
	// (nthread workers + the main thread inside waitfor)
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	std::cout << "Initializing ";
	// List being tested
	relaxed_list<Node> list = { nthread * nqueues };
	{
		enable_stats = true;

		std::thread * threads[nthread];
		unsigned i = 1;
		for(auto & t : threads) {
			t = new std::thread([&done, &list, &barrier, &global, nnodes](unsigned tid) {
				Random rand(tid + rdtscl());

				// Worker-owned nodes; they stay alive until this lambda
				// returns, which is after the second barrier.
				Node nodes[nnodes];
				for(auto & n : nodes) {
					n.value = (int)rand.next() % 100;
				}

				local_stat_t local;

				// affinity(tid);

				barrier.wait(tid);

				// EXPERIMENT START

				runPingPong_body(done, nodes, nnodes, local, list);

				// EXPERIMENT END

				barrier.wait(tid);

				tally_stats(global, local);
			}, i++);
		}

		waitfor(duration, barrier, done);

		for(auto t : threads) {
			t->join();
			delete t;
		}

		enable_stats = false;
	}

	print_stats(duration, nthread, global);
}

// ================================================================================================
__attribute__((noinline)) void runFairness_body(
	unsigned tid,
	size_t width,
	size_t length,
	int output[],
	std::atomic_size_t & count,
	Node initial_nodes[],
	unsigned nnodes,
	local_stat_t & local,
	relaxed_list<Node> & list
) {
	Node * nodes[nnodes];
	{
		unsigned i = 0;
		for(auto & n : nodes) {
			n = &initial_nodes[i++];
		}
	}

	while(__builtin_expect(0 != count.load(std::memory_order_relaxed), true)) {

		for(Node * & node : nodes) {
			local.crc_in += node->id;
			list.push(node);
			local.in++;
		}

		// -----

		for(Node * & node : nodes) {
			node = list.pop();
			assert(node);

			if (unsigned(node->value) < length) {
				size_t idx = (node->value * width) + node->id;
				assert(idx < (width * length));
				output[idx] = tid;
			}

			node->value++;
			if(unsigned(node->value) == length) count--;

			local.crc_out += node->id;
			local.out++;
		}
	}
}

// Fairness benchmark: each worker owns `nnodes` nodes with globally unique
// ids (0 .. nthread*nnodes-1). Every pop records, in data_out[value * width + id],
// the tid of the popping thread and increments the node's counter. The run
// ends once every node has been popped `length` times (tracked by `count`),
// and the recorded matrix is written to `output` by save_fairness().
//   nthread  : number of worker threads (tids 1..nthread; main thread is 0)
//   nqueues  : queues per thread; the list is built with nthread * nqueues
//   duration : not a time limit here; waitfor() overwrites it with the
//              measured run time before it is printed
//   nnodes   : nodes owned by each worker
//   output   : destination file for save_fairness()
void runFairness(unsigned nthread, unsigned nqueues, double duration, unsigned nnodes, const std::string & output) {
	std::cout << "Fairness Benchmark, outputing to : " << output << std::endl;

	// Barrier for synchronization
	// (nthread workers + the main thread inside waitfor)
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	std::cout << "Initializing ";

	// Check fairness by recording which thread performed each pop:
	// one column per node, one row per pop round of that node.
	size_t width = nthread * nnodes;
	size_t length = 100000;

	std::unique_ptr<int[]> data_out { new int[width * length] };

	// Flag to signal termination: number of nodes whose counter has not yet
	// reached `length`; workers stop when it drops to 0.
	std::atomic_size_t count = width;

	// List being tested
	relaxed_list<Node> list = { nthread * nqueues };
	{
		enable_stats = true;

		std::thread * threads[nthread];
		unsigned i = 1;
		for(auto & t : threads) {
			t = new std::thread([&count, &list, &barrier, &global, nnodes, width, length, data_out = data_out.get()](unsigned tid) {
				// Ids of this worker's nodes: [(tid - 1) * nnodes, tid * nnodes).
				unsigned int start = (tid - 1) * nnodes;
				Node nodes[nnodes];
				for(auto & n : nodes) {
					n.id = start;
					n.value = 0;
					start++;
				}

				local_stat_t local;

				// affinity(tid);

				barrier.wait(tid);

				// EXPERIMENT START

				runFairness_body(tid, width, length, data_out, count, nodes, nnodes, local, list);

				// EXPERIMENT END

				barrier.wait(tid);

				// Final pop counters of this worker's own nodes.
				for(const auto & n : nodes) {
					local.valmax = max(local.valmax, size_t(n.value));
					local.valmin = min(local.valmin, size_t(n.value));
				}

				tally_stats(global, local);
			}, i++);
		}

		waitfor(duration, barrier, count);

		for(auto t : threads) {
			t->join();
			delete t;
		}

		enable_stats = false;
	}

	print_stats(duration, nthread, global);

	save_fairness(data_out.get(), 100, nthread, width, length, output);
}

// ================================================================================================

// Case-insensitive comparison of two strings using std::tolower in the
// current C locale; strings of different lengths never compare equal.
// Each char is converted to unsigned char first: passing a negative char
// (e.g. a UTF-8 byte on platforms where char is signed) to std::tolower is
// undefined behavior.
bool iequals(const std::string& a, const std::string& b)
{
	return std::equal(a.begin(), a.end(),
	                  b.begin(), b.end(),
	                  [](char lhs, char rhs) {
	                      return std::tolower(static_cast<unsigned char>(lhs))
	                          == std::tolower(static_cast<unsigned char>(rhs));
	                  });
}

// Parse `text` as a strictly positive integer that fits in an `unsigned`.
// std::stoul alone is not enough: it accepts a leading '-' (silently wrapping
// the value), throws std::out_of_range on overflow, and returns an
// `unsigned long` that may not fit in `unsigned`. Returns false and leaves
// `out` untouched on empty input, any '-', trailing characters, zero, or a
// value that does not fit.
static bool parse_positive_arg(const std::string & text, unsigned & out) {
	if(text.empty() || text.find('-') != std::string::npos) return false;
	try {
		size_t used = 0;
		const unsigned long value = std::stoul(text, &used);
		if(used != text.size()) return false;
		if(value == 0 || value != static_cast<unsigned>(value)) return false;
		out = static_cast<unsigned>(value);
		return true;
	} catch(std::logic_error &) {
		// std::invalid_argument (no digits) or std::out_of_range (overflow).
		return false;
	}
}

// Entry point: parse the command line, then run the selected benchmark.
//
//   [options] -b churn    [NNODES] [NSLOTS = NNODES]
//   [options] -b pingpong [NNODES]
//   [options] -b fairness [OUTPUT]
//
// Options: -d/--duration (seconds), -t/--nthreads, -q/--nqueues (queues per
// thread). Counts must be positive integers; for churn NNODES <= NSLOTS.
// Invalid input prints the usage text and exits with status 1.
int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;
	unsigned nqueues  = 4;
	unsigned nnodes   = 100;
	unsigned nslots   = 100;
	std::string out   = "fairness.png";

	enum {
		Churn,
		PingPong,
		Fairness,
		NONE
	} benchmark = NONE;

	// Use the environment's locale (digit grouping in the printed numbers).
	// std::locale("") throws if the environment names an unavailable locale;
	// in that case keep std::cout's default locale instead of terminating.
	try {
		std::cout.imbue(std::locale(""));
	} catch(std::runtime_error &) {
		std::cerr << "Warning: could not load the environment locale, using the default" << std::endl;
	}

	for(;;) {
		static struct option options[] = {
			{"duration",  required_argument, 0, 'd'},
			{"nthreads",  required_argument, 0, 't'},
			{"nqueues",   required_argument, 0, 'q'},
			{"benchmark", required_argument, 0, 'b'},
			{0, 0, 0, 0}
		};

		int idx = 0;
		int opt = getopt_long(argc, argv, "d:t:q:b:", options, &idx);

		std::string arg = optarg ? optarg : "";
		size_t len = 0;
		switch(opt) {
			// Exit Case: all options consumed, parse the positional arguments.
			case -1:
				/* paranoid */ assert(optind <= argc);
				switch(benchmark) {
				case NONE:
					std::cerr << "Must specify a benchmark" << std::endl;
					goto usage;
				case PingPong:
					nnodes = 1;
					nslots = 1;
					switch(argc - optind) {
					case 0: break;
					case 1:
						arg = optarg = argv[optind];
						if(!parse_positive_arg(arg, nnodes)) {
							std::cerr << "Number of nodes must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						break;
					default:
						std::cerr << "'PingPong' benchmark doesn't accept more than 1 extra argument" << std::endl;
						goto usage;
					}
					break;
				case Churn:
					nnodes = 100;
					nslots = 100;
					switch(argc - optind) {
					case 0: break;
					case 1:
					case 2:
						arg = optarg = argv[optind];
						if(!parse_positive_arg(arg, nnodes)) {
							std::cerr << "Number of nodes must be a positive integer, was " << arg << std::endl;
							goto usage;
						}
						nslots = nnodes;
						if(argc - optind == 2) {
							arg = optarg = argv[optind + 1];
							if(!parse_positive_arg(arg, nslots)) {
								std::cerr << "Number of slots must be a positive integer, was " << arg << std::endl;
								goto usage;
							}
						}
						break;
					default:
						std::cerr << "'Churn' benchmark doesn't accept more than 2 extra arguments" << std::endl;
						goto usage;
					}
					// runChurn only asserts this; with NDEBUG a violation
					// would overrun the per-thread slot arrays.
					if(nnodes > nslots) {
						std::cerr << "Number of nodes (" << nnodes << ") cannot exceed number of slots (" << nslots << ")" << std::endl;
						goto usage;
					}
					break;
				case Fairness:
					nnodes = 1;
					switch(argc - optind) {
					case 0: break;
					case 1:
						arg = optarg = argv[optind];
						out = arg;
						break;
					default:
						std::cerr << "'Fairness' benchmark doesn't accept more than 1 extra argument" << std::endl;
						goto usage;
					}
					break;
				}
				goto run;
			// Benchmarks
			case 'b':
				if(benchmark != NONE) {
					std::cerr << "Only one benchmark can be run" << std::endl;
					goto usage;
				}
				if(iequals(arg, "churn")) {
					benchmark = Churn;
					break;
				}
				if(iequals(arg, "pingpong")) {
					benchmark = PingPong;
					break;
				}
				if(iequals(arg, "fairness")) {
					benchmark = Fairness;
					break;
				}
				std::cerr << "Unknown benchmark " << arg << std::endl;
				goto usage;
			// Numeric Arguments
			case 'd':
				try {
					duration = std::stod(arg, &len);
					if(len != arg.size()) { throw std::invalid_argument(""); }
				} catch(std::logic_error &) {
					// invalid_argument (not a number) or out_of_range (overflow)
					std::cerr << "Duration must be a valid double, was " << arg << std::endl;
					goto usage;
				}
				break;
			case 't':
				// Zero threads would later divide by zero in print_stats().
				if(!parse_positive_arg(arg, nthreads)) {
					std::cerr << "Number of threads must be a positive integer, was " << arg << std::endl;
					goto usage;
				}
				break;
			case 'q':
				if(!parse_positive_arg(arg, nqueues)) {
					std::cerr << "Number of queues must be a positive integer, was " << arg << std::endl;
					goto usage;
				}
				break;
			// Other cases
			default: /* ? */
				std::cerr << opt << std::endl;
			usage:
				std::cerr << "Usage: " << argv[0] << ": [options] -b churn [NNODES] [NSLOTS = NNODES]" << std::endl;
				std::cerr << "  or:  " << argv[0] << ": [options] -b pingpong [NNODES]" << std::endl;
				std::cerr << "  or:  " << argv[0] << ": [options] -b fairness [OUTPUT]" << std::endl;
				std::cerr << std::endl;
				std::cerr << "  -d, --duration=DURATION  Duration of the experiment, in seconds" << std::endl;
				std::cerr << "  -t, --nthreads=NTHREADS  Number of kernel threads" << std::endl;
				std::cerr << "  -q, --nqueues=NQUEUES    Number of queues per threads" << std::endl;
				std::cerr << "  -b, --benchmark=NAME     Benchmark to run: churn, pingpong or fairness" << std::endl;
				std::exit(1);
		}
	}
	run:

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads (" << (nthreads * nqueues) << " queues) for " << duration << " seconds" << std::endl;
	switch(benchmark) {
		case Churn:
			runChurn(nthreads, nqueues, duration, nnodes, nslots);
			break;
		case PingPong:
			runPingPong(nthreads, nqueues, duration, nnodes);
			break;
		case Fairness:
			runFairness(nthreads, nqueues, duration, nnodes, out);
			break;
		default:
			abort();
	}
	return 0;
}

// Program-name string with external linkage. Nothing in this file reads it,
// so presumably another translation unit declares it `extern` — TODO confirm.
// NOTE(review): identifiers beginning with "__" are reserved for the
// implementation in C++; renaming would require updating its other users.
const char * __my_progname = "Relaxed List";

// Colour as red/green/blue channels, each a fraction between 0 and 1.
struct rgb_t {
	double r;
	double g;
	double b;
};

// Colour as hue (angle in degrees), saturation and value (fractions between 0 and 1).
struct hsv_t {
	double h;
	double s;
	double v;
};

// Convert an HSV colour to RGB.
// Saturation <= 0 yields a grey of brightness `v`. Otherwise the hue circle is
// split into six 60-degree sectors (a hue >= 360 counts as 0). Within a sector
// one channel is `v`, one is the floor `p`, and the third ramps with the
// position inside the sector (`q` falling, `t` rising).
rgb_t hsv2rgb(hsv_t in) {
	if(in.s <= 0.0) {       // <= rather than == keeps float-compare warnings quiet
		return rgb_t{ in.v, in.v, in.v };
	}

	const double sector_pos = (in.h >= 360.0 ? 0.0 : in.h) / 60.0;
	const long   sector     = static_cast<long>(sector_pos);
	const double frac       = sector_pos - sector;

	const double p = in.v * (1.0 - in.s);
	const double q = in.v * (1.0 - (in.s * frac));
	const double t = in.v * (1.0 - (in.s * (1.0 - frac)));

	switch(sector) {
		case 0:  return rgb_t{ in.v, t,    p    };
		case 1:  return rgb_t{ q,    in.v, p    };
		case 2:  return rgb_t{ p,    in.v, t    };
		case 3:  return rgb_t{ p,    q,    in.v };
		case 4:  return rgb_t{ t,    p,    in.v };
		default: return rgb_t{ in.v, p,    q    };  // sector 5, or out of range
	}
}

// Write `rows` x `columns` thread ids from `data` (row-major) to `output` as
// an HTML table. Each cell gets the classes "custom customN", where N is the
// stored id.
//   factor, nthreads : not used by this HTML writer (kept for the interface)
// Failure to open or write the file is reported on std::cerr rather than
// silently producing nothing.
void save_fairness(const int data[], int factor, unsigned nthreads, size_t columns, size_t rows, const std::string & output) {
	(void)factor;
	(void)nthreads;

	std::ofstream os(output);
	if(!os) {
		std::cerr << "Could not open file " << output << " for writing" << std::endl;
		return;
	}

	os << "<html>\n";
	os << "<head>\n";
	os << "<style>\n";
	os << "</style>\n";
	os << "</head>\n";
	os << "<body>\n";
	// Inline CSS declarations use "property:value"; "width=100%" was ignored.
	os << "<table style=\"width:100%\">\n";

	size_t idx = 0;
	for(size_t r = 0ul; r < rows; r++) {
		os << "<tr>\n";
		for(size_t c = 0ul; c < columns; c++) {
			os << "<td class=\"custom custom" << data[idx] << "\"></td>\n";
			idx++;
		}
		os << "</tr>\n";
	}

	os << "</table>\n";
	os << "</body>\n";
	os << "</html>\n";
	os << std::endl;

	if(!os) {
		std::cerr << "Error while writing " << output << std::endl;
	}
}

#include <png.h>
#include <setjmp.h>

/*
void save_fairness(const int data[], int factor, unsigned nthreads, size_t columns, size_t rows, const std::string & output) {
	int width  = columns * factor;
	int height = rows / factor;

	int code = 0;
	int idx = 0;
	FILE *fp = NULL;
	png_structp png_ptr = NULL;
	png_infop info_ptr = NULL;
	png_bytep row = NULL;

	// Open file for writing (binary mode)
	fp = fopen(output.c_str(), "wb");
	if (fp == NULL) {
		fprintf(stderr, "Could not open file %s for writing\n", output.c_str());
		code = 1;
		goto finalise;
	}

	   // Initialize write structure
	png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
	if (png_ptr == NULL) {
		fprintf(stderr, "Could not allocate write struct\n");
		code = 1;
		goto finalise;
	}

	// Initialize info structure
	info_ptr = png_create_info_struct(png_ptr);
	if (info_ptr == NULL) {
		fprintf(stderr, "Could not allocate info struct\n");
		code = 1;
		goto finalise;
	}

	// Setup Exception handling
	if (setjmp(png_jmpbuf(png_ptr))) {
		fprintf(stderr, "Error during png creation\n");
		code = 1;
		goto finalise;
	}

	png_init_io(png_ptr, fp);

	// Write header (8 bit colour depth)
	png_set_IHDR(png_ptr, info_ptr, width, height,
		8, PNG_COLOR_TYPE_RGB, PNG_INTERLACE_NONE,
		PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);

	png_write_info(png_ptr, info_ptr);

	// Allocate memory for one row (3 bytes per pixel - RGB)
	row = (png_bytep) malloc(3 * width * sizeof(png_byte));

	// Write image data
	int x, y;
	for (y=0 ; y<height ; y++) {
		for (x=0 ; x<width ; x++) {
			auto & r = row[(x * 3) + 0];
			auto & g = row[(x * 3) + 1];
			auto & b = row[(x * 3) + 2];
			assert(idx < (rows * columns));
			int color = data[idx] - 1;
			assert(color < nthreads);
			assert(color >= 0);
			idx++;

			double angle = double(color) / double(nthreads);

			auto c = hsv2rgb({ 360.0 * angle, 0.8, 0.8 });

			r = char(c.r * 255.0);
			g = char(c.g * 255.0);
			b = char(c.b * 255.0);

		}
		png_write_row(png_ptr, row);
	}

	assert(idx == (rows * columns));

	// End write
	png_write_end(png_ptr, NULL);

	finalise:
	if (fp != NULL) fclose(fp);
	if (info_ptr != NULL) png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
	if (png_ptr != NULL) png_destroy_write_struct(&png_ptr, (png_infopp)NULL);
	if (row != NULL) free(row);
}
*/