#include "processor_list.hpp"

#include <atomic>
#include <cassert>
#include <cstddef>
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <vector>

// Returns the fixed constant 0x1000000 (2^24).
unsigned num() {
	constexpr unsigned value = 1u << 24;
	return value;
}

// Reusable spinning barrier for a fixed number of participants.
//
// A single, monotonically increasing arrival counter is shared by every
// round: each caller takes a ticket, computes the next multiple of `total`
// above that ticket and busy-waits until the counter reaches it. Because the
// counter is never reset, the same object can be used for consecutive rounds.
class barrier_t {
public:
	// total: number of wait() calls required to release one round.
	// Must be non-zero because wait() computes a remainder by it.
	barrier_t(size_t total)
		: waiting(0)
		, total(total)
	{
		assert(total > 0);
	}

	// Busy-waits until `total` callers have arrived for the current round.
	// The unsigned argument is ignored.
	void wait(unsigned) {
		const size_t ticket = waiting++;
		// First multiple of `total` strictly above this caller's ticket
		const size_t target = (ticket - (ticket % total)) + total;
		while(waiting < target) {
			relax();
		}

		// The counter only grows; make sure it stays far away from wrapping.
		// Written as a right shift of all-ones so the bound is well defined
		// whatever the width of size_t is (2^60 - 1 when size_t is 64 bits).
		assert(waiting <= (~size_t(0) >> 4));
	}

private:
	// Spin-wait hint: the x86 `pause` instruction where it exists, otherwise
	// give up the time slice so the file still builds on other targets.
	static void relax() {
#if defined(__x86_64__) || defined(__i386__)
		asm volatile("pause");
#else
		std::this_thread::yield();
#endif
	}

	std::atomic<size_t> waiting;
	size_t total;
};

// Small xorshift-style pseudo-random number generator.
// The whole state is one unsigned int which is scrambled in place by three
// shift-and-xor steps on every call to next().
class Random {
private:
	unsigned int seed;
public:
	// seed: initial state. Zero is a fixed point of the shift-and-xor steps
	// (the generator would return 0 forever), so a zero seed is replaced by
	// an arbitrary non-zero constant.
	Random(int seed)
		: seed(0 != seed ? static_cast<unsigned int>(seed) : 0x9E3779B9u)
	{}

	/** advances the state and returns it; the result can be any unsigned int value. **/
	unsigned int next() {
		seed ^= seed << 6;
		seed ^= seed >> 21;
		seed ^= seed << 7;
		return seed;
	}
};

//-------------------

// Minimal stand-in for a processor used by the tests below.
struct processor {
	// Value returned by processor_list::doregister() for this object
	unsigned int id;
};

// Stage 1
// Make sure that the early registration works correctly
// Registration uses a different process if the act of
// registering the processor makes it the highest processor count
// seen yet.
void stage1(unsigned nthread, unsigned repeats) {
	const int n = repeats;
	const int nproc = 10;

	// List being tested
	processor_list list;

	// Barrier for synchronization
	barrier_t barrier(nthread + 1);

	// Seen values to detect duplicattion
	std::atomic<processor *> ids[nthread * nproc];
	for(auto & i : ids) {
		i = nullptr;
	}

	// Can't pass VLA to lambda
	std::atomic<processor *> * idsp = ids;

	// Threads which will run the code
	std::thread * threads[nthread];
	unsigned i = 1;
	for(auto & t : threads) {
		// Each thread will try to register a processor then add it to the
		// list of registerd processor
		t = new std::thread([&list, &barrier, idsp, n](unsigned tid){
			processor proc[nproc];
			for(int i = 0; i < n; i++) {
				for(auto & p : proc) {
					// Register the thread
					p.id = list.doregister(&p);
				}

				for(auto & p : proc) {
					// Make sure no one got this id before
					processor * prev = idsp[p.id].exchange(&p);
					assert(nullptr == prev);

					// Make sure id is still consistend
					assert(&p == list.get(p.id));
				}

				// wait for round to finish
				barrier.wait(tid);

				// wait for reset
				barrier.wait(tid);
			}
		}, i++);
	}

	for(int i = 0; i < n; i++) {
		//Wait for round to finish
		barrier.wait(0);

		// Reset list
		list.reset();

		std::cout << i << "\r";

		// Reset seen values
		for(auto & i : ids) {
			i = nullptr;
		}

		// Start next round
		barrier.wait(0);
	}

	for(auto t : threads) {
		t->join();
		delete t;
	}
}

// Stage 2
// Check that once churning starts, registration is still consistent.
//
// nthread worker threads each run `repeats` iterations of: register ten
// processors, check list.get() returns each of them, then unregister them
// and check unregister() returns the same pointer.
void stage2(unsigned nthread, unsigned repeats) {
	// List being tested
	processor_list list;

	// Threads which will run the code (std::vector instead of a non-standard
	// variable length array of owning raw pointers)
	std::vector<std::thread> threads;
	threads.reserve(nthread);
	for(unsigned t = 0; t < nthread; t++) {
		// Each thread will try to register a few processors and
		// unregister them, making sure that the registration is
		// consistent. Thread ids start at 1.
		threads.emplace_back([&list, repeats](unsigned tid){
			processor procs[10];
			for(unsigned i = 0; i < repeats; i++) {
				// register the procs and note the id
				for(auto & p : procs) {
					p.id = list.doregister(&p);
				}

				// Only the first worker reports progress
				if(1 == tid) std::cout << i << "\r";

				// check the id is still consistent
				for(const auto & p : procs) {
					assert(&p == list.get(p.id));
				}

				// unregister and check the id is consistent.
				// The call is kept outside of assert() so the processors are
				// still unregistered when asserts are compiled out (NDEBUG).
				for(const auto & p : procs) {
					const auto removed = list.unregister(p.id);
					assert(&p == removed);
					(void)removed;
				}
			}
		}, t + 1);
	}

	for(auto & t : threads) {
		t.join();
	}
}

// NOTE(review): declared here but neither defined nor called anywhere in this
// file; presumably provided by another translation unit — confirm or remove.
bool is_writer();

// Stage 3
// Check that the reader writer lock works.
//
// Writers increment `before` at the start and `*after` at the end of a write
// locked section; readers check under the read lock that both counters are
// equal, i.e. that they never observe a write section half way through.
// The test stops once a writer observes that `before` reached `repeats`.
void stage3(unsigned nthread, unsigned repeats) {
	// List being tested
	processor_list list;

	// Incremented first inside the write locked section
	size_t before = 0;

	// Incremented last inside the write locked section (lives on the heap)
	std::unique_ptr<size_t> after( new size_t(0) );

	// Set by a writer to tell every thread to stop
	std::atomic<bool> done ( false );

	// Threads which will run the code (std::vector instead of a non-standard
	// variable length array of owning raw pointers)
	std::vector<std::thread> threads;
	threads.reserve(nthread);
	for(unsigned t = 0; t < nthread; t++) {
		// Each thread registers one processor, then loops until `done`,
		// acting as a writer when rng.next() % 100 == 0 and as a reader
		// otherwise. Thread ids start at 1 and seed the per-thread rng.
		threads.emplace_back([&list, repeats, &before, &after, &done](unsigned tid){
			Random rng(tid);
			processor proc;
			proc.id = list.doregister(&proc);
			while(!done) {
				const bool writer = (rng.next() % 100) == 0;
				if( writer ) {
					auto r = list.write_lock();

					auto b = before++;

					std::cout << b << "\r";

					(*after)++;

					if(b >= repeats) done = true;

					list.write_unlock(r);
				}
				else {
					list.read_lock(proc.id);
					// Both counters must match outside of a write section
					assert(before == *after);
					list.read_unlock(proc.id);
				}
			}

			list.unregister(proc.id);
		}, t + 1);
	}

	for(auto & t : threads) {
		t.join();
	}
}

// Parses `text` as an unsigned count.
// Returns true and stores the value in `out` only when the whole string is a
// number that fits in an unsigned; otherwise returns false and leaves `out`
// untouched. Never throws: failures reported by std::stoul are caught.
static bool parse_count(const char * text, unsigned & out) {
	try {
		size_t idx = 0;
		const unsigned long value = std::stoul(text, &idx);

		// Reject trailing garbage such as "12abc"
		if('\0' != text[idx]) return false;

		// Reject values that would be truncated by the conversion to unsigned
		if(value > std::numeric_limits<unsigned>::max()) return false;

		out = static_cast<unsigned>(value);
		return true;
	}
	catch(const std::exception &) {
		// std::stoul found no digits or the value does not fit an unsigned long
		return false;
	}
}

// Entry point: runs the three test stages in order.
// Usage: <program> [repeats [nthreads]]
//   repeats  - value passed to every stage as its repeat count (default 100)
//   nthreads - number of worker threads per stage (default 1)
// Returns 1 when an argument is not a valid count, 0 otherwise.
int main(int argc, char * argv[]) {

	unsigned nthreads = 1;
	if( argc >= 3 && !parse_count(argv[2], nthreads) ) {
		std::cerr << "Invalid thread count '" << argv[2] << "'" << std::endl;
		std::cerr << "Usage: " << argv[0] << " [repeats [nthreads]]" << std::endl;
		return 1;
	}

	unsigned repeats = 100;
	if( argc >= 2 && !parse_count(argv[1], repeats) ) {
		std::cerr << "Invalid repetition count '" << argv[1] << "'" << std::endl;
		std::cerr << "Usage: " << argv[0] << " [repeats [nthreads]]" << std::endl;
		return 1;
	}

	processor_list::check_cache_line_size();

	std::cout << "Running " << repeats << " repetitions on " << nthreads << " threads" << std::endl;
	std::cout << "Checking registration - early" << std::endl;
	stage1(nthreads, repeats);
	// Trailing spaces overwrite the "\r" progress output of the stage
	std::cout << "Done                         " << std::endl;

	std::cout << "Checking registration - churn" << std::endl;
	stage2(nthreads, repeats);
	std::cout << "Done                         " << std::endl;

	std::cout << "Checking RW lock             " << std::endl;
	stage3(nthreads, repeats);
	std::cout << "Done                         " << std::endl;


	return 0;
}