#include "thrdlib/thread.hpp"

#include <getopt.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
using thrdlib::thread_t;

//--------------------
// Constants
// NOTE: despite the heading, these are mutable globals. main() assigns their
// defaults and may overwrite them from the command line before any thread is
// created.

// Number of Frame slots allocated in main(); Simulator and Renderer index the
// slots with `i % nframes`.
unsigned nframes;
// Number of bytes allocated for, written into and read back from each frame's
// data buffer.
unsigned fsize;
// Total number of frames the Simulator writes and the Renderer reads.
unsigned nproduce;

//--------------------
// Frame management

// One buffer slot shared between the Simulator (writer) and the Renderer
// (reader), together with the two hand-off flags used to pass it back and
// forth.
//
// Each flag is an atomic thread_t that holds one of three kinds of value:
//   - `reset` (nullptr)      : not signalled, nobody waiting
//   - `set`   (the value 1)  : signalled
//   - any other thread handle: stored by wait() just before that thread parks
class Frame {
	static const thread_t reset;
	static const thread_t set;
	// "Frame is ready to be rendered" flag; starts out not signalled.
	std::atomic<thread_t> rdy_state = { reset };
	// "Frame has been rendered" flag; starts out signalled, so the first
	// wait_rendered() on a fresh frame returns without parking.
	std::atomic<thread_t> rnd_state = { set };
public:
	// Value stored by the Simulator: the index of the frame it wrote here.
	unsigned number;
	// Payload buffer; main() allocates it with `fsize` bytes.
	std::unique_ptr<unsigned char[]> data;

private:
	// Waits until `state` has been signalled, then consumes the signal.
	//
	// self  : handle of the calling thread; stored in `state` and passed to
	//         thrdlib::park() when the flag is not yet signalled.
	// state : the flag being waited on.
	// other : the opposite flag of the same frame; only checked by an assert.
	//
	// Returns true if the caller went through thrdlib::park(), false if the
	// flag was already `set` on entry. In both cases `state` is stored back
	// to `reset` before returning.
	inline bool wait( thread_t self, std::atomic<thread_t> & state, std::atomic<thread_t> & other ) {
		bool ret;
		while(true) {
			thread_t expected = state;
			// Already signalled: no need to park.
			if( expected == set ) { ret = false; goto END; }
			assert( expected == reset );
			// Try to register ourselves as the waiter. If the exchange fails
			// the flag changed since it was loaded, so loop and re-read it.
			if( std::atomic_compare_exchange_strong( &state, &expected, self) ) {
				// NOTE(review): presumably park() blocks until publish()
				// calls unpark() on this handle -- confirm against thrdlib.
				thrdlib::park( self );
				ret = true;
				goto END;
			}
		}
		END:
		// On both paths the flag is expected to be signalled by now, and the
		// opposite flag is expected not to be.
		assert( state == set );
		assert( other != set );
		// Consume the signal so the flag can be used for the next hand-off.
		state = reset;
		return ret;
	}

	// Signals `state` and wakes the waiter registered in it, if any.
	//
	// Returns true if a waiting thread handle was found in `state` and
	// thrdlib::unpark() was called on it, false if the flag was `reset`.
	// The flag must not already be `set` (checked by assert).
	inline bool publish(  std::atomic<thread_t> & state ) {
		// Atomically mark the flag signalled and fetch what it held before.
		thread_t got = std::atomic_exchange( &state, set );
		assert( got != set );

		// Nobody was waiting.
		if( got == reset ) return false;

		// `got` is the handle stored by wait(); wake that thread.
		thrdlib::unpark( got );
		return true;
	}

public:
	// Waits on the "rendered" flag. Returns true if the caller parked.
	inline bool wait_rendered( thread_t self ) {
		return wait( self, rnd_state, rdy_state );
	}

	// Waits on the "ready" flag. Returns true if the caller parked.
	inline bool wait_ready   ( thread_t self ) {
		return wait( self, rdy_state, rnd_state );
	}

	// Signals the "ready" flag. Returns true if a waiter was unparked.
	inline bool publish() {
		return publish( rdy_state );
	}

	// Signals the "rendered" flag. Returns true if a waiter was unparked.
	inline bool release() {
		return publish( rnd_state );
	}
};

// Sentinel flag values. `set` is the value 1 cast to thread_t, which the code
// relies on never being equal to a real thread handle.
const thread_t Frame::reset = nullptr;
const thread_t Frame::set   = reinterpret_cast<thread_t>(1);

// Frame slots shared by Simulator and Renderer. main() allocates `nframes`
// of them and both threads index the array with `i % nframes`.
std::unique_ptr<Frame[]> frames;
// Number of frames the Simulator has finished writing: it stores `i + 1`
// after writing frame `i`. The Stats thread polls this value.
volatile unsigned last_produced = 0;

//--------------------
// Threads
// Handle of the Stats thread. Stats() stores its own handle here, main()
// yields until the value is non-null, and Renderer() passes it to
// thrdlib::unpark().
thread_t volatile the_stats_thread = nullptr;

// Issues a sequentially consistent memory fence on the calling thread.
inline void fence() {
	constexpr std::memory_order ordering = std::memory_order_seq_cst;
	std::atomic_thread_fence( ordering );
}

// Counters collected while the threads run and printed by main() once they
// have all been joined.
struct {
	// Incremented by Simulator(): `parks` when wait_rendered() returns true,
	// `unparks` when publish() returns true.
	struct {
		volatile unsigned long long   parks = 0;
		volatile unsigned long long unparks = 0;
	} sim;
	// Incremented by Renderer(): `parks` when wait_ready() returns true,
	// `unparks` when release() returns true.
	struct {
		volatile unsigned long long   parks = 0;
		volatile unsigned long long unparks = 0;
	} rend;

	// Written by Stats(): `ran` counts iterations of its polling loop, `saw`
	// is the number of distinct `last_produced` values it observed.
	struct {
		volatile unsigned long long ran = 0;
		volatile unsigned long long saw = 0;
	} stats;
} thrd_stats;

// Observer thread: records which values of `last_produced` it manages to see
// while the Simulator is running.
//
// self : handle of this thread.
//
// Results are written to thrd_stats.stats: `ran` is the number of polling
// iterations, `saw` the number of distinct frame counts observed.
void Stats( thread_t self ) {
	// Make our handle visible, then park. Renderer() calls unpark() on
	// the_stats_thread as its first action.
	the_stats_thread = self;
	fence();
	thrdlib::park( self );

	// One flag per frame index, all initially false.
	std::vector<bool> observed( nproduce, false );

	for(;;) {
		if( !(last_produced < nproduce) ) break;

		thrdlib::yield();
		thrd_stats.stats.ran = thrd_stats.stats.ran + 1;

		// last_produced holds "index of newest frame + 1"; zero means no
		// frame has been produced yet.
		if( last_produced > 0 ) {
			observed.at(last_produced - 1) = true;
		}
	}

	thrd_stats.stats.saw = std::count(observed.begin(), observed.end(), true);
}

// State of the generator below: a single 64-bit word.
using __wyhash64_state_t = uint64_t;

// Advances `state` by a fixed constant and returns a 64-bit value derived
// from the new state by two rounds of "multiply into 128 bits, then xor the
// high half into the low half".
static inline uint64_t __wyhash64( __wyhash64_state_t & state ) {
	state += 0x60bee2bee120fc15;

	// Round 1: widen, multiply, fold.
	const __uint128_t wide_first = static_cast<__uint128_t>(state) * 0xa3b195354a39b70d;
	const uint64_t    fold_first = static_cast<uint64_t>( (wide_first >> 64) ^ wide_first );

	// Round 2: same shape with a different multiplier.
	const __uint128_t wide_second = static_cast<__uint128_t>(fold_first) * 0x1b03738712fad5c9;
	return static_cast<uint64_t>( (wide_second >> 64) ^ wide_second );
}

// Producer thread: writes `nproduce` frames into the slot array in
// round-robin order and publishes each one.
//
// self : handle of this thread, passed to Frame::wait_rendered().
void Simulator( thread_t self ) {
	unsigned index = 0;
	while( index < nproduce ) {
		Frame & slot = frames[index % nframes];

		// Do not overwrite a slot until its "rendered" flag has been signalled.
		const bool parked = slot.wait_rendered( self );
		if( parked ) {
			thrd_stats.sim.parks = thrd_stats.sim.parks + 1;
		}

		// Fill the slot: frame number first, then `fsize` generated bytes.
		// The generator restarts from state 0 for every frame.
		slot.number = index;
		__wyhash64_state_t rng = 0;
		unsigned char * const bytes = slot.data.get();
		for( unsigned pos = 0; pos != fsize; ++pos ) {
			bytes[pos] = __wyhash64(rng);
		}
		std::cout << "Simulated " << index << std::endl;
		last_produced = index + 1;

		// Signal the "ready" flag; count it when a waiter was woken.
		const bool woke = slot.publish();
		if( woke ) {
			thrd_stats.sim.unparks = thrd_stats.sim.unparks + 1;
		}

		++index;
	}
}

// Consumer thread: reads `nproduce` frames from the slot array in
// round-robin order and releases each one.
//
// self : handle of this thread, passed to Frame::wait_ready().
void Renderer( thread_t self ) {
	// Wake the Stats thread, which parks itself right after start-up.
	thrdlib::unpark( the_stats_thread );

	unsigned index = 0;
	while( index < nproduce ) {
		Frame & slot = frames[index % nframes];

		// Block until this slot's "ready" flag has been signalled.
		const bool parked = slot.wait_ready( self );
		if( parked ) {
			thrd_stats.rend.parks = thrd_stats.rend.parks + 1;
		}

		// "Render" the frame by summing every byte of its payload.
		unsigned total = 0;
		const unsigned char * const bytes = slot.data.get();
		for( unsigned pos = 0; pos != fsize; ++pos ) {
			total += bytes[pos];
		}

		std::cout << "Rendered " << index << std::endl;
		// assert(total == i * fsize);

		// Signal the "rendered" flag; count it when a waiter was woken.
		const bool woke = slot.release();
		if( woke ) {
			thrd_stats.rend.unparks = thrd_stats.rend.unparks + 1;
		}

		++index;
	}

}



// Parses `text` as a strictly positive count that fits in `unsigned`.
//
// Returns true and stores the value in `out` on success. Returns false and
// leaves `out` untouched when the text is empty or not a number, has trailing
// characters, contains a minus sign, is zero, or does not fit in `unsigned`.
static bool parse_positive_count( const std::string & text, unsigned & out ) {
	// std::stoul accepts a leading '-' and wraps the result around, so a
	// negative argument must be rejected explicitly.
	if( text.find('-') != std::string::npos ) return false;

	try {
		std::size_t used = 0;
		const unsigned long value = std::stoul( text, &used );
		if( used != text.size() ) return false;
		if( value == 0 || value > std::numeric_limits<unsigned>::max() ) return false;
		out = static_cast<unsigned>( value );
		return true;
	} catch( const std::logic_error & ) {
		// std::stoul throws std::invalid_argument when no digits are found
		// and std::out_of_range on overflow; both derive from logic_error.
		return false;
	}
}

// Prints the command-line synopsis to stderr and exits with status 1.
[[noreturn]] static void usage_and_exit( const char * program ) {
	std::cerr << "Usage: " << program << " [options]" << std::endl;
	std::cerr << std::endl;
	std::cerr << "  -b, --buff=COUNT    Number of frames to buffer" << std::endl;
	std::cerr << "  -p, --nprod=COUNT   Number of frames to produce" << std::endl;
	std::cerr << "  -f, --fsize=SIZE    Size of each frame in bytes" << std::endl;
	std::exit(1);
}

// Entry point: parses the options, allocates the frame slots, runs the Stats,
// Renderer and Simulator threads to completion and prints the collected
// counters.
//
// Options (each takes a positive integer):
//   -b, --buff   number of frame slots   (default 3)
//   -p, --nprod  number of frames to run (default 60)
//   -f, --fsize  bytes per frame         (default 3840 * 2160 * 4 * 4)
//
// Exits with status 1 after printing usage on any invalid option or value.
int main(int argc, char * argv[]) {
	nframes  = 3;
	fsize    = 3840 * 2160 * 4 * 4;
	nproduce = 60;

	static const struct option long_options[] = {
		{"buff",  required_argument, nullptr, 'b'},
		{"nprod", required_argument, nullptr, 'p'},
		{"fsize", required_argument, nullptr, 'f'},
		{nullptr, 0, nullptr, 0}
	};

	for(;;) {
		int idx = 0;
		const int opt = getopt_long(argc, argv, "b:p:f:", long_options, &idx);
		if( opt == -1 ) break; // no more options

		const std::string arg = optarg ? optarg : "";
		switch(opt) {
			case 'b':
				if( !parse_positive_count( arg, nframes ) ) {
					std::cerr << "Number of buffered frames must be at least 1, was " << arg << std::endl;
					usage_and_exit( argv[0] );
				}
				break;
			case 'p':
				if( !parse_positive_count( arg, nproduce ) ) {
					std::cerr << "Number of produced frames must be at least 1, was " << arg << std::endl;
					usage_and_exit( argv[0] );
				}
				break;
			case 'f':
				if( !parse_positive_count( arg, fsize ) ) {
					std::cerr << "Size of produced frames must be at least 1, was " << arg << std::endl;
					usage_and_exit( argv[0] );
				}
				break;
			// Other cases
			default: /* ? */
				std::cerr << opt << std::endl;
				usage_and_exit( argv[0] );
		}
	}

	// Allocate the slots and their payload buffers.
	frames.reset(new Frame[nframes]);
	for(unsigned i = 0; i < nframes; i++) {
		frames[i].number = 0;
		frames[i].data.reset(new unsigned char[fsize]);
	}
	std::cout << "Created frames of " << fsize << " bytes" << std::endl;
	std::cout << "(Buffering " << nframes << ")" << std::endl;

	thrdlib::init( 2 );

	// The Stats thread must have stored its handle before the Renderer
	// starts, because the Renderer unparks it through the_stats_thread.
	thread_t stats     = thrdlib::create( Stats );
	std::cout << "Created Stats Thread" << std::endl;
	while( the_stats_thread == nullptr ) thrdlib::yield();

	std::cout << "Creating Main Threads" << std::endl;
	thread_t renderer  = thrdlib::create( Renderer  );
	thread_t simulator = thrdlib::create( Simulator );

	std::cout << "Running" << std::endl;

	thrdlib::join( simulator );
	thrdlib::join( renderer  );
	thrdlib::join( stats     );

	thrdlib::clean();

	// thrd_stats.rend is updated by Renderer() and thrd_stats.sim by
	// Simulator(); each label below is paired with its own thread's counters.
	std::cout << "----------" << std::endl;
	std::cout << "# Parks" << std::endl;
	std::cout << "  Renderer   park: " << thrd_stats.rend.  parks << std::endl;
	std::cout << "  Renderer unpark: " << thrd_stats.rend.unparks << std::endl;
	std::cout << " Simulator   park: " << thrd_stats. sim.  parks << std::endl;
	std::cout << " Simulator unpark: " << thrd_stats. sim.unparks << std::endl;

	std::cout << "Stats thread" << std::endl;
	std::cout << " Ran             : " << thrd_stats.stats.ran << " times" << std::endl;
	std::cout << " Saw             : " << thrd_stats.stats.saw << " (" << ((100.f * thrd_stats.stats.saw) / nproduce) << "%)" << std::endl;
}