#include "processor_list.hpp"

#include <array>
#include <iomanip>
#include <iostream>
#include <locale>
#include <string>
#include <thread>

#include "utils.hpp"

unsigned num() {
	return 0x1000000;
}

//-------------------

struct processor {
	unsigned id;
};
void run(unsigned nthread, double duration, unsigned writes, unsigned epochs) {
	assert(writes < 100);

	// List being tested
	processor_list list = {};

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	size_t write_committed = 0ul;
	struct {
		std::atomic_size_t write = { 0ul };
		std::atomic_size_t read  = { 0ul };
		std::atomic_size_t epoch = { 0ul };
	} lock_cnt;

	// Flag to signal termination
	std::atomic_bool done = { false };

	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		t = new std::thread([&done, &list, &barrier, &write_committed, &lock_cnt, writes, epochs](unsigned tid) {
			Random rand(tid + rdtscl());
			processor proc;
			proc.id = list.doregister(&proc);
			size_t writes_cnt = 0;
			size_t reads_cnt = 0;
			size_t epoch_cnt = 0;

			affinity(tid);

			barrier.wait(tid);

			while(__builtin_expect(!done, true)) {
				auto r = rand.next() % 100;
				if (r < writes) {
					auto n = list.write_lock();
					write_committed++;
					writes_cnt++;
					assert(writes_cnt < -2ul);
					list.write_unlock(n);
				}
				else if(r < epochs) {
					list.epoch_check();
					epoch_cnt++;
				}
				else {
					list.read_lock(proc.id);
					reads_cnt++;
					assert(reads_cnt < -2ul);
					list.read_unlock(proc.id);
				}
			}

			barrier.wait(tid);

			auto p = list.unregister(proc.id);
			assert(&proc == p);
			lock_cnt.write += writes_cnt;
			lock_cnt.read  += reads_cnt;
			lock_cnt.epoch += epoch_cnt;
		}, i++);
	}

	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(1000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();

	for(auto t : threads) {
		t->join();
		delete t;
	}

	assert(write_committed == lock_cnt.write);

	size_t totalop = lock_cnt.read + lock_cnt.write + lock_cnt.epoch;
	size_t ops_sec = size_t(double(totalop) / duration);
	size_t ops_thread = ops_sec / nthread;
	double dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "Total ops     : " << totalop << "(" << lock_cnt.read << "r, " << lock_cnt.write << "w, " << lock_cnt.epoch << "e)\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "ns/Op         : " << ( dur_nano / ops_thread )<< "\n";
}

void usage(char * argv[]) {
	std::cerr << argv[0] << ": [DURATION (FLOAT:SEC)] [NTHREADS] [%WRITES]" << std::endl;;
	std::exit(1);
}

int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;
	unsigned writes   = 0;
	unsigned epochs   = 0;

	std::cout.imbue(std::locale(""));

	switch (argc)
	{
	case 5:
		epochs = std::stoul(argv[4]);
		[[fallthrough]];
	case 4:
		writes = std::stoul(argv[3]);
		if( (writes + epochs) > 100 ) {
			std::cerr << "Writes + Epochs must be valid percentage, was " << argv[3] << " + " << argv[4] << "(" << writes << " + " << epochs << ")" << std::endl;
			usage(argv);
		}
		[[fallthrough]];
	case 3:
		nthreads = std::stoul(argv[2]);
		[[fallthrough]];
	case 2:
		duration = std::stod(argv[1]);
		if( duration <= 0.0 ) {
			std::cerr << "Duration must be positive, was " << argv[1] << "(" << duration << ")" << std::endl;
			usage(argv);
		}
		[[fallthrough]];
	case 1:
		break;
	default:
		usage(argv);
		break;
	}

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads for " << duration << " seconds with " << writes << "% writes and " << epochs << "% epochs" << std::endl;
	run(nthreads, duration, writes, epochs + writes);

	return 0;
}
