
#include <locale>
#include <stdexcept>

#include "../utils.hpp"

// Opaque sink for the two benchmarked indexes. It is kept out of line
// (noinline) and both arguments are fed to an empty volatile asm statement, so
// the compiler has to treat each value as used and cannot discard the work
// that produced it.
void consume(int i, int j) __attribute__((noinline));
void consume(int i, int j) {
	// One input operand per argument: an argument missing from this list
	// would be dead inside the sink and could be optimized away with its
	// producing computation (e.g. by interprocedural parameter removal).
	asm volatile("" :: "rm" (i), "rm" (j));
}

// Pick a pseudo-random set bit of `mask` and return its index (0 = LSB).
//
// The rank of the chosen bit is `rnum % popcount(mask)`. It is turned into a
// bit position in software with the branchless 64-bit "select the bit
// position with the given rank" routine from Bit Twiddling Hacks. That
// routine counts ranks from the MOST significant set bit, so rank 0 is the
// highest set bit. rand_bit_hw counts from the least significant set bit
// instead, so the two can return different bits for the same arguments.
//
// Returns 0 when `mask` is 0.
static inline unsigned rand_bit_sw(unsigned rnum, size_t mask) {
	// 64-bit all-ones word: ones/3, ones/5, ones/0x11 and ones/0x101 are the
	// 0x5555..., 0x3333..., 0x0f0f... and 0x00ff00ff... popcount masks.
	// Spelled with uint64_t (and popcountll below) rather than ~0UL/popcountl
	// so the arithmetic stays 64-bit where `long` is only 32 bits wide.
	const uint64_t ones = ~uint64_t(0);

	unsigned bit = mask ? rnum % __builtin_popcountll(mask) : 0;
	uint64_t v = mask;        // Input value to find position with rank r.
	unsigned int r = bit + 1; // Desired rank [1-64], counted from the MSB.
	unsigned int s;           // Position of the rank-r bit, 1-based from the LSB.
	uint64_t a, b, c, d;      // Intermediate per-field bit counts.
	unsigned int t;           // Bit count temporary.

	// Normal parallel bit count for a 64-bit integer, keeping every
	// intermediate step: a/b/c/d hold the counts of each 2/4/8/16-bit field.
	a =  v - ((v >> 1) & ones/3);
	b = (a & ones/5) + ((a >> 2) & ones/5);
	c = (b + (b >> 4)) & ones/0x11;
	d = (c + (c >> 8)) & ones/0x101;

	// Low bits of t = number of set bits in v's upper 32 bits. The higher bits
	// of t hold left-over field data that the first select step is written to
	// tolerate.
	t = (d >> 32) + (d >> 48);
	// Branchless select. Each line is `if (r > t) { s -= width; r -= t; }`
	// for width 32, 16, 8, 4, 2, 1, halving the search window every step.
	s  = 64;
	s -= ((t - r) & 256) >> 3; r -= (t & ((t - r) >> 8));
	t  = (d >> (s - 16)) & 0xff;
	s -= ((t - r) & 256) >> 4; r -= (t & ((t - r) >> 8));
	t  = (c >> (s - 8)) & 0xf;
	s -= ((t - r) & 256) >> 5; r -= (t & ((t - r) >> 8));
	t  = (b >> (s - 4)) & 0x7;
	s -= ((t - r) & 256) >> 6; r -= (t & ((t - r) >> 8));
	t  = (a >> (s - 2)) & 0x3;
	s -= ((t - r) & 256) >> 7; r -= (t & ((t - r) >> 8));
	t  = (v >> (s - 1)) & 0x1;
	s -= ((t - r) & 256) >> 8;
	return s - 1;
}

// Pick a pseudo-random set bit of `mask` and return its index (0 = LSB),
// using BMI2 PDEP. Depositing a single 1 at position `rank` through `mask`
// lands it on the rank-th lowest set bit (rank 0 = lowest set bit), where
// rank = rnum % popcount(mask). Returns 0 when `mask` is 0.
// Requires a BMI2-enabled build (_pdep_u64).
static inline unsigned rand_bit_hw(unsigned rnum, size_t mask) {
	unsigned bit = mask ? rnum % __builtin_popcountll(mask) : 0;
	// A 64-bit one (not `1ul`) keeps the shift well-defined for bit >= 32 even
	// where `long` is 32 bits; likewise ctzll/popcountll take 64-bit operands.
	uint64_t picked = _pdep_u64(uint64_t(1) << bit, mask);
	return picked ? __builtin_ctzll(picked) : 0;
}

// Holder for the benchmark's random number generator, constructed with 6
// (presumably a fixed seed so runs are repeatable; confirm against Random).
// NOTE(review): despite the name, `tls` is an ordinary global, not
// thread_local.
struct TLS {
	Random rng{ 6 };
};
TLS tls;

// Total number of lists the strategies choose between.
constexpr unsigned numLists = 64;

// Baseline strategy: no mask lookup at all. Draw two random numbers, reduce
// each modulo numLists, and hand the pair to the sink.
static inline void blind() {
	const int first  = tls.rng.next() % numLists;
	const int second = tls.rng.next() % numLists;

	consume(first, second);
}

// Mask words for the dense-bitmask strategies: bit k of list_mask[w] stands
// for list index w*64 + k (the layout bitmask_sw/bitmask_hw rebuild).
// Presumably a set bit marks a list as eligible/non-empty; confirm with the
// real data structure this benchmark models.
//
// Nothing in this benchmark stores to these words, so they are seeded with a
// fixed non-zero pattern. With all-zero masks, rand_bit_sw/rand_bit_hw take
// their `mask ? ... : 0` shortcut (no modulo by the popcount), and the
// dense-bitmask timings would not reflect a real lookup.
constexpr size_t list_mask_seed = size_t(0xff00ffff00ffull);
std::atomic_size_t list_mask[7] = {
	{ list_mask_seed }, { list_mask_seed }, { list_mask_seed },
	{ list_mask_seed }, { list_mask_seed }, { list_mask_seed },
	{ list_mask_seed },
};
// Dense-bitmask strategy with a software select. Each random number picks a
// 64-list mask word from its bits 6 and up, then a set bit of that word via
// rand_bit_sw. The global list index is rebuilt as (word << 6) | bit.
static inline void bitmask_sw() {
	// Number of 64-bit mask words needed to cover numLists lists.
	const unsigned words = ((numLists - 1) >> 6) + 1;

	// Map one random draw to a list index.
	auto pick = [&](unsigned rnd) -> unsigned {
		const unsigned word = (rnd >> 6u) % words;
		const size_t bits = list_mask[word].load(std::memory_order_relaxed);
		return rand_bit_sw(rnd, bits) | (word << 6);
	};

	const unsigned ri = tls.rng.next();
	const unsigned rj = tls.rng.next();

	const unsigned i = pick(ri);
	const unsigned j = pick(rj);

	consume(i, j);
}

// Dense-bitmask strategy with a hardware select. Same index layout as
// bitmask_sw: the mask word comes from bits 6 and up of each random number,
// and the set bit within it is found by rand_bit_hw (BMI2 PDEP).
// The body is compiled only when __BMI2__ is defined. Otherwise this is an
// empty function and the build emits a warning, so no BMI2 intrinsic code is
// compiled into a non-BMI2 build and no unreachable code follows a `return`.
static inline void bitmask_hw() {
	#if !defined(__BMI2__)
		#warning NO bmi2 for pdep rand_bit
	#else
	unsigned i, j;
	{
		// Pick two lists at random.
		// Number of 64-bit mask words needed to cover numLists lists.
		unsigned num = ((numLists - 1) >> 6) + 1;

		unsigned ri = tls.rng.next();
		unsigned rj = tls.rng.next();

		// Mask word index from bits 6 and up of each random number.
		unsigned wdxi = (ri >> 6u) % num;
		unsigned wdxj = (rj >> 6u) % num;

		size_t maski = list_mask[wdxi].load(std::memory_order_relaxed);
		size_t maskj = list_mask[wdxj].load(std::memory_order_relaxed);

		unsigned bi = rand_bit_hw(ri, maski);
		unsigned bj = rand_bit_hw(rj, maskj);

		// Global list index = word * 64 + bit.
		i = bi | (wdxi << 6);
		j = bj | (wdxj << 6);
	}

	consume(i, j);
	#endif
}

// Fixed stand-in for a sparse node-mask structure used by sparsemask().
// NOTE(review): masks() ignores its node argument and returns the same
// pattern for every node; presumably a placeholder for real per-node masks.
struct {
	const unsigned mask = 7;   // Selects the node id from a random number; equals (1 << depth) - 1.
	const unsigned depth = 3;  // Number of low list-index bits holding the node id.
	const uint64_t indexes = 0x0706050403020100; // Byte k holds the value k.
	// Byte mask (0xff per selected byte of `indexes`) for the given node.
	// Const: it reads no mutable state.
	uint64_t masks( unsigned /*node*/ ) const {
		return 0xff00ffff00ff;
	}
} snzm;
// Sparse-mask strategy using BMI2 PEXT. For each random number:
//   * the low bits selected by snzm.mask choose a node;
//   * PEXT compacts that node's selected index bytes out of snzm.indexes;
//   * the remaining (higher) random bits choose one of those bytes.
// The list index is rebuilt as (chosen index << depth) | node.
// The body is compiled only when __BMI2__ is defined. Otherwise this is an
// empty function and the build emits a warning.
static inline void sparsemask() {
	#if !defined(__BMI2__)
		#warning NO bmi2 for sparse mask
	#else
	unsigned i, j;
	{
		// Pick two random numbers.
		unsigned ri = tls.rng.next();
		unsigned rj = tls.rng.next();

		// Pick a node from the low bits of each.
		unsigned wdxi = ri & snzm.mask;
		unsigned wdxj = rj & snzm.mask;

		// Get the masks from the nodes.
		size_t maski = snzm.masks(wdxi);
		size_t maskj = snzm.masks(wdxj);

		// Pack the selected index bytes into the low end.
		uint64_t idxsi = _pext_u64(snzm.indexes, maski);
		uint64_t idxsj = _pext_u64(snzm.indexes, maskj);

		// Number of selected bytes (each selected byte contributes 8 set bits).
		unsigned cnti = __builtin_popcountll(maski) >> 3;
		unsigned cntj = __builtin_popcountll(maskj) >> 3;

		// Choose a slot from the bits above the node id, so the slot is not a
		// function of the node. Use `%` so counts that are not powers of two
		// are handled, and guard cnt == 0 (empty or partial-byte mask).
		unsigned sloti = cnti ? (ri >> snzm.depth) % cnti : 0;
		unsigned slotj = cntj ? (rj >> snzm.depth) % cntj : 0;

		unsigned bi = (idxsi >> (sloti << 3)) & 0xff;
		unsigned bj = (idxsj >> (slotj << 3)) & 0xff;

		i = (bi << snzm.depth) | wdxi;
		j = (bj << snzm.depth) | wdxj;
	}

	consume(i, j);
	#endif
}

template<typename T>
void benchmark( T func, const std::string & name ) {
	std::cout << "Starting " << name << std::endl;
	auto before = Clock::now();
	const int N = 250'000'000;
	for(int i = 0; i < N; i++) {
		func();
	}
	auto after = Clock::now();
	duration_t durr = after - before;
	double duration = durr.count();
	std::cout << "Duration(s) : " << duration << std::endl;
	std::cout << "Ops/sec     : " << uint64_t(N / duration) << std::endl;
	std::cout << "ns/Op       : " << double(duration * 1'000'000'000.0 / N) << std::endl;
	std::cout << std::endl;
}

int main() {
	// Print numbers using the user's locale (e.g. digit grouping).
	// std::locale("") throws std::runtime_error when the environment names a
	// locale that is not available. Keep the default formatting in that case
	// rather than aborting before any benchmark runs.
	try {
		std::cout.imbue(std::locale(""));
	} catch (const std::runtime_error &) {
		// Stay on the classic "C" locale already imbued in std::cout.
	}

	// Wrap each candidate in its own lambda so every benchmark<> instantiation
	// gets a distinct callable type the compiler can inline. Passing the
	// functions directly would share a single benchmark<void(*)()> and may add
	// an indirect call to every measured iteration.
	benchmark([]{ blind(); }, "Blind guess");
	benchmark([]{ bitmask_sw(); }, "Dense bitmask");
	benchmark([]{ bitmask_hw(); }, "Dense bitmask with Parallel Deposit");
	benchmark([]{ sparsemask(); }, "Parallel Extract bitmask");
}