#include <array>
#include <iomanip>
#include <iostream>
#include <locale>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <getopt.h>
#include <unistd.h>
#include <sys/sysinfo.h>

#include "utils.hpp"

// ================================================================================================
//                        UTILS
// ================================================================================================

// Per-thread statistics. Each worker owns one instance, so the counter is a
// plain (non-atomic) integer.
struct local_stat_t {
	// Number of completed operations; starts at zero.
	size_t cnt{0};
};

// Statistics aggregated over all threads. The counter is atomic because
// several threads add their local totals to it.
struct global_stat_t {
	// Total number of completed operations; starts at zero.
	std::atomic_size_t cnt{0};
};

// Atomically raises `target` to `value` if `value` is larger than what is
// currently stored; otherwise leaves `target` untouched.
void atomic_max(std::atomic_size_t & target, size_t value) {
	size_t observed = target.load(std::memory_order_relaxed);
	// Keep trying only while our value would still be an improvement.
	while(observed < value) {
		// On failure compare_exchange_strong refreshes `observed` with the
		// value currently stored, so the loop condition re-checks against it.
		if(target.compare_exchange_strong(observed, value)) return;
	}
}

// Atomically lowers `target` to `value` if `value` is smaller than what is
// currently stored; otherwise leaves `target` untouched.
void atomic_min(std::atomic_size_t & target, size_t value) {
	size_t observed = target.load(std::memory_order_relaxed);
	// Keep trying only while our value would still be an improvement.
	while(observed > value) {
		// On failure compare_exchange_strong refreshes `observed` with the
		// value currently stored, so the loop condition re-checks against it.
		if(target.compare_exchange_strong(observed, value)) return;
	}
}

// Folds one thread's local operation count into the shared global total.
// The addition is a single atomic read-modify-write on `global.cnt`.
void tally_stats(global_stat_t & global, local_stat_t & local) {
	global.cnt.fetch_add(local.cnt);
}

void waitfor(double & duration, barrier_t & barrier, std::atomic_bool & done) {
	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
		std::cout << "\r" << std::setprecision(4) << durr.count();
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

void waitfor(double & duration, barrier_t & barrier, const std::atomic_size_t & count) {
	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		size_t c = count.load();
		if( c == 0 ) {
			break;
		}
		std::cout << "\r" << c;
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

// Prints the summary of a finished experiment to stdout.
//
// duration  measured run time in seconds.
// nthread   number of worker threads that took part.
// global    aggregated counters; only `cnt` (total operations) is read.
//
// Degenerate runs (no threads, or less than one operation per second per
// thread) would divide by zero, so they report 0 ops/sec/thread and "n/a"
// for the per-operation latency.
void print_stats(double duration, unsigned nthread, global_stat_t & global) {
	std::cout << "Done" << std::endl;

	size_t ops = global.cnt;
	size_t ops_sec = size_t(double(ops) / duration);
	// Integer division: guard against nthread == 0.
	size_t ops_thread = (nthread != 0) ? (ops_sec / nthread) : 0;
	// One second expressed in nanoseconds (duration_cast is the project helper from utils.hpp).
	auto dur_nano = duration_cast<std::nano>(1.0);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "ns/Op         : ";
	if(ops_thread != 0) {
		std::cout << ( dur_nano / ops_thread );
	} else {
		// No meaningful latency can be derived from zero throughput.
		std::cout << "n/a";
	}
	std::cout << "\n";
	std::cout << "Ops/sec/thread: " << ops_thread << "\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Total ops     : " << ops << "\n";
}

// Atomic "bit test and set": sets bit number `bit` of `target` and returns
// whether that bit was already set beforehand.
//
// Portable implementation on top of fetch_or with relaxed ordering. An x86-64
// inline-assembly alternative was kept disabled in this function:
//
//	int result = 0;
//	asm volatile(
//		"LOCK btsq %[bit], %[target]\n\t"
//		:"=@ccc" (result)
//		: [target] "m" (target), [bit] "r" (bit)
//	);
//	return result != 0;
static inline bool bts(std::atomic_size_t & target, size_t bit ) {
	const size_t mask = 1ul << bit;
	const size_t previous = target.fetch_or(mask, std::memory_order_relaxed);
	return (previous & mask) != 0;
}

// Atomic "bit test and reset": clears bit number `bit` of `target` and
// returns whether that bit was set beforehand.
//
// Portable implementation on top of fetch_and with relaxed ordering. An
// x86-64 inline-assembly alternative was kept disabled in this function:
//
//	int result = 0;
//	asm volatile(
//		"LOCK btrq %[bit], %[target]\n\t"
//		:"=@ccc" (result)
//		: [target] "m" (target), [bit] "r" (bit)
//	);
//	return result != 0;
static inline bool btr(std::atomic_size_t & target, size_t bit ) {
	const size_t mask = 1ul << bit;
	const size_t previous = target.fetch_and(~mask, std::memory_order_relaxed);
	return (previous & mask) != 0;
}

// ================================================================================================
//                        EXPERIMENTS
// ================================================================================================

// ================================================================================================
// Worker loop of the ping-pong experiment: repeatedly sets and then clears
// this worker's own bit in the shared word until `done` is raised.
//
// done    termination flag, read with relaxed ordering once per iteration.
// local   per-thread statistics; `cnt` is incremented once per set/clear pair.
// target  the shared word operated on by every worker.
// id      index of the bit this worker toggles.
//
// Kept out-of-line (noinline) as in the original declaration.
__attribute__((noinline)) void runPingPong_body(
	std::atomic<bool>& done,
	local_stat_t & local,
	std::atomic_size_t & target,
	size_t id
) {
	for(;;) {
		// Stopping is the rare case; hint the branch accordingly.
		const bool stop = done.load(std::memory_order_relaxed);
		if(__builtin_expect(stop, false)) break;

		// The bit belongs to this worker alone, so it must have been clear.
		const bool was_set = bts(target, id);
		assert(!was_set);

		// -----

		// ... and it must still be set when we clear it again.
		const bool was_still_set = btr(target, id);
		assert(was_still_set);

		++local.cnt;
	}
}

// Runs the ping-pong experiment with `nthread` worker threads for about
// `duration` seconds and prints the resulting statistics.
//
// Thread ids: the coordinating (calling) thread uses id 0 on the barrier
// (inside waitfor); workers use ids 1..nthread. Worker `tid` toggles bit
// `tid - 1` of the shared word, so callers must keep nthread within the bit
// width of size_t (main enforces this).
void run(unsigned nthread, double duration) {
	// Barrier for synchronization: all workers plus the coordinating thread
	barrier_t barrier(nthread + 1);

	// Data to check everything is OK
	global_stat_t global;

	// Flag to signal termination
	std::atomic_bool done  = { false };

	std::cout << "Initializing ";
	// Shared word whose bits the workers set and clear
	std::atomic_size_t word = { 0 };
	{
		// The vector owns the threads; no manual new/delete and no
		// variable-length array (a non-standard extension) needed.
		std::vector<std::thread> threads;
		threads.reserve(nthread);
		for(unsigned tid = 1; tid <= nthread; tid++) {
			threads.emplace_back([&done, &word, &barrier, &global, tid]() {
				local_stat_t local;

				// affinity(tid);

				barrier.wait(tid);

				// EXPERIMENT START

				runPingPong_body(done, local, word, tid - 1);

				// EXPERIMENT END

				barrier.wait(tid);

				tally_stats(global, local);
			});
		}

		// Coordinates start/stop with the workers and updates `duration`
		// with the time actually measured.
		waitfor(duration, barrier, done);

		// Every worker must have tallied its counters before we print.
		for(auto & t : threads) {
			t.join();
		}
	}

	print_stats(duration, nthread, global);
}

// ================================================================================================

// Prints the command-line help to stderr and terminates the process with
// exit status 1.
[[noreturn]] static void print_usage_and_exit(const char * program) {
	std::cerr << "Usage: " << program << ": [options]" << std::endl;
	std::cerr << std::endl;
	std::cerr << "  -d, --duration=DURATION  Duration of the experiment, in seconds" << std::endl;
	std::cerr << "  -t, --nthreads=NTHREADS  Number of kernel threads" << std::endl;
	std::exit(1);
}

// Entry point: parses -d/--duration and -t/--nthreads, then runs the
// experiment. Any invalid option or argument prints a diagnostic plus the
// usage text and exits with status 1.
int main(int argc, char * argv[]) {

	double duration   = 5.0;
	unsigned nthreads = 2;

	// Each worker owns one bit of a single size_t word, which caps the
	// number of threads at the bit width of size_t.
	constexpr size_t max_threads = 8 * sizeof(size_t);

	std::cout.imbue(std::locale(""));

	static const struct option options[] = {
		{"duration",  required_argument, 0, 'd'},
		{"nthreads",  required_argument, 0, 't'},
		{0, 0, 0, 0}
	};

	for(;;) {
		int idx = 0;
		int opt = getopt_long(argc, argv, "d:t:", options, &idx);

		if(opt == -1) {
			// End of options: anything left over is a stray positional argument.
			if(optind != argc) {
				std::cerr << "Too many arguments " << argc << " " << idx << std::endl;
				print_usage_and_exit(argv[0]);
			}
			break;
		}

		std::string arg = optarg ? optarg : "";
		size_t len = 0;
		switch(opt) {
			// Numeric Arguments
			case 'd':
				try {
					duration = std::stod(arg, &len);
					// Reject trailing garbage such as "5abc".
					if(len != arg.size()) { throw std::invalid_argument(""); }
				} catch(const std::logic_error &) {
					// logic_error covers both invalid_argument (not a number)
					// and out_of_range (e.g. "1e999") thrown by std::stod.
					std::cerr << "Duration must be a valid double, was " << arg << std::endl;
					print_usage_and_exit(argv[0]);
				}
				break;
			case 't':
				try {
					// Validate the full unsigned long before narrowing, so an
					// oversized value cannot wrap around into the valid range.
					const unsigned long parsed = std::stoul(arg, &len);
					if(len != arg.size() || parsed == 0 || parsed > max_threads) { throw std::invalid_argument(""); }
					nthreads = static_cast<unsigned>(parsed);
				} catch(const std::logic_error &) {
					// Also catches out_of_range from std::stoul on huge inputs.
					std::cerr << "Number of threads must be a positive integer less than or equal to " << sizeof(size_t) * 8 << ", was " << arg << std::endl;
					print_usage_and_exit(argv[0]);
				}
				break;
			// Other cases
			default: /* ? */
				std::cerr << opt << std::endl;
				print_usage_and_exit(argv[0]);
		}
	}

	check_cache_line_size();

	std::cout << "Running " << nthreads << " threads for " << duration << " seconds" << std::endl;
	run(nthreads, duration);
	return 0;
}