#include <cstddef>
#include <cstdint>
#include <x86intrin.h>

/**
 * Branchless rank-select: locate the set bit of `mask` that has the given
 * rank, ranks being counted from the most-significant end.
 *
 * A parallel population count is run with every intermediate level kept
 * (per-2, per-4, per-8 and per-16-bit field sums); a six-step binary search
 * then descends through those levels without branching. This is the
 * rank-select routine from the "Bit Twiddling Hacks" collection.
 *
 * @param mask  Value whose set bits are searched.
 * @param bit   Desired rank, 1-based; expected to be in [1, popcount(mask)].
 * @return      Position of the selected bit as (bit index + 1), in [1, 64].
 *              The closing "s = 65 - s" of the published routine is commented
 *              out, so the position is relative to the least-significant bit.
 *              For a rank outside [1, popcount(mask)] the result does not
 *              identify a bit of the requested rank.
 */
__attribute__((noinline)) unsigned nthSetBit(size_t mask, unsigned bit) {
	// All-ones of exactly 64 bits. Deriving the field masks from `~0UL`
	// would yield 32-bit masks wherever `unsigned long` is 32 bits wide.
	const uint64_t ones = ~uint64_t{0};

	uint64_t v = mask;    // Input value to find position with rank r.
	unsigned int r = bit; // Input: bit's desired rank [1-64].
	unsigned int s;       // Output: resulting position of bit with rank r [1-64].
	uint64_t a, b, c, d;  // Intermediate levels of the parallel bit count.
	unsigned int t;       // Bit count of the half currently being tested.

	// Parallel bit count for a 64-bit integer, keeping every level.
	// a: per-2-bit sums   (mask 0x5555...)
	a =  v - ((v >> 1) & ones/3);
	// b: per-4-bit sums   (mask 0x3333...)
	b = (a & ones/5) + ((a >> 2) & ones/5);
	// c: per-8-bit sums   (mask 0x0f0f...)
	c = (b + (b >> 4)) & ones/0x11;
	// d: per-16-bit sums  (mask 0x00ff...)
	d = (c + (c >> 8)) & ones/0x101;

	// Low 16 bits of t: number of set bits in the upper 32 bits of v.
	// The truncation to 32 bits leaves the top field's count in t's high
	// half; the tests below only look at bit 8 of (t - r), so a literal
	// `if (r > t)` would NOT be an equivalent replacement for this step.
	t = (d >> 32) + (d >> 48);

	// Branchless select. In each step, when r > t the unsigned difference
	// t - r wraps, which sets bit 8 (used to shrink s) and turns
	// (t - r) >> 8 into a mask that keeps t (used to reduce r).
	s  = 64;
	// if (r > t) {s -= 32; r -= t;}
	s -= ((t - r) & 256) >> 3; r -= (t & ((t - r) >> 8));
	t  = (d >> (s - 16)) & 0xff;
	// if (r > t) {s -= 16; r -= t;}
	s -= ((t - r) & 256) >> 4; r -= (t & ((t - r) >> 8));
	t  = (c >> (s - 8)) & 0xf;
	// if (r > t) {s -= 8; r -= t;}
	s -= ((t - r) & 256) >> 5; r -= (t & ((t - r) >> 8));
	t  = (b >> (s - 4)) & 0x7;
	// if (r > t) {s -= 4; r -= t;}
	s -= ((t - r) & 256) >> 6; r -= (t & ((t - r) >> 8));
	t  = (a >> (s - 2)) & 0x3;
	// if (r > t) {s -= 2; r -= t;}
	s -= ((t - r) & 256) >> 7; r -= (t & ((t - r) >> 8));
	t  = (v >> (s - 1)) & 0x1;
	// if (r > t) s--;
	s -= ((t - r) & 256) >> 8;
	// s = 65 - s;   (left disabled: callers want the LSB-relative position)
	return s;
}

/**
 * Pick one of the set bits of `mask`, chosen by `rnum`.
 *
 * The rank of the chosen bit is `rnum % popcount(mask)`. Exactly one of the
 * macros BRANCHLESS, LOOP or PDEP must be defined to select the strategy:
 *   - BRANCHLESS: rank-select via a stored parallel popcount (the same
 *                 routine as nthSetBit, written out inline here); ranks are
 *                 counted from the most-significant set bit.
 *   - LOOP:       clear the lowest set bit `rank` times, then take the lowest
 *                 remaining one; ranks are counted from the least-significant
 *                 set bit.
 *   - PDEP:       deposit a single bit with `_pdep_u64`; ranks are counted
 *                 from the least-significant set bit.
 *
 * @param rnum  Arbitrary number used to choose among the set bits.
 * @param mask  Bit set to pick from.
 * @return      0-based index of the chosen bit. When `mask` is 0 there is no
 *              bit to pick and the value is not a meaningful index
 *              (0 for BRANCHLESS, UINT_MAX for LOOP and PDEP).
 */
unsigned rand_bit(unsigned rnum, uint64_t mask) {
	// The `ll` builtins operate on `unsigned long long`, so all 64 bits of
	// the mask are counted even where `unsigned long` is only 32 bits wide.
	unsigned bit = mask ? rnum % __builtin_popcountll(mask) : 0;
#if defined(BRANCHLESS)
	// All-ones of exactly 64 bits; `~0UL` would be too narrow on LLP64.
	const uint64_t ones = ~uint64_t{0};

	uint64_t v = mask;        // Input value to find position with rank r.
	unsigned int r = bit + 1; // Input: bit's desired rank [1-64].
	unsigned int s;           // Output: resulting position of bit with rank r [1-64].
	uint64_t a, b, c, d;      // Intermediate levels of the parallel bit count.
	unsigned int t;           // Bit count of the half currently being tested.

	// Parallel bit count for a 64-bit integer, keeping every level.
	// a: per-2-bit sums   (mask 0x5555...)
	a =  v - ((v >> 1) & ones/3);
	// b: per-4-bit sums   (mask 0x3333...)
	b = (a & ones/5) + ((a >> 2) & ones/5);
	// c: per-8-bit sums   (mask 0x0f0f...)
	c = (b + (b >> 4)) & ones/0x11;
	// d: per-16-bit sums  (mask 0x00ff...)
	d = (c + (c >> 8)) & ones/0x101;

	// Low 16 bits of t: number of set bits in the upper 32 bits of v.
	t = (d >> 32) + (d >> 48);

	// Branchless select: when r > t the unsigned difference t - r wraps,
	// setting bit 8 (shrinks s) and making (t - r) >> 8 a mask that keeps t.
	s  = 64;
	// if (r > t) {s -= 32; r -= t;}
	s -= ((t - r) & 256) >> 3; r -= (t & ((t - r) >> 8));
	t  = (d >> (s - 16)) & 0xff;
	// if (r > t) {s -= 16; r -= t;}
	s -= ((t - r) & 256) >> 4; r -= (t & ((t - r) >> 8));
	t  = (c >> (s - 8)) & 0xf;
	// if (r > t) {s -= 8; r -= t;}
	s -= ((t - r) & 256) >> 5; r -= (t & ((t - r) >> 8));
	t  = (b >> (s - 4)) & 0x7;
	// if (r > t) {s -= 4; r -= t;}
	s -= ((t - r) & 256) >> 6; r -= (t & ((t - r) >> 8));
	t  = (a >> (s - 2)) & 0x3;
	// if (r > t) {s -= 2; r -= t;}
	s -= ((t - r) & 256) >> 7; r -= (t & ((t - r) >> 8));
	t  = (v >> (s - 1)) & 0x1;
	// if (r > t) s--;
	s -= ((t - r) & 256) >> 8;
	// s is the 1-based position from the least-significant end; make it 0-based.
	return s - 1;
#elif defined(LOOP)
	// Drop the `bit` lowest set bits; the loop runs fewer than popcount(mask)
	// times, so mask is non-zero whenever ffs is evaluated inside it.
	for(unsigned i = 0; i < bit; i++) {
		mask ^= (uint64_t{1} << (__builtin_ffsll(mask) - 1));
	}
	// ffs is 1-based and yields 0 for an empty mask.
	return static_cast<unsigned>(__builtin_ffsll(mask) - 1);
#elif defined(PDEP)
	// Bit `bit` of the source lands on the `bit`-th lowest set position of mask.
	uint64_t picked = _pdep_u64(uint64_t{1} << bit, mask);
	return static_cast<unsigned>(__builtin_ffsll(picked) - 1);
#else
#error must define LOOP, PDEP or BRANCHLESS
#endif
}

#include <cassert>
#include <atomic>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <locale>
#include <thread>

#include <unistd.h>

/**
 * Reusable spinning barrier built on a single, ever-increasing arrival
 * counter. Each arrival takes a ticket; the arrival is released once the
 * counter reaches the next multiple of `total` above its ticket.
 */
class barrier_t {
private:
	std::atomic<size_t> waiting; // Number of arrivals so far, never reset.
	size_t total;                // Arrivals required to release one round.

public:
	/** @param total  Number of wait() calls that make up one round. */
	barrier_t(size_t total)
		: waiting(0)
		, total(total)
	{}

	/**
	 * Block (busy-wait) until the current round is complete.
	 * The argument is unnamed and unused.
	 */
	void wait(unsigned) {
		const size_t ticket = waiting.fetch_add(1);
		// First multiple of `total` strictly greater than the ticket.
		const size_t release_at = (ticket / total + 1) * total;

		while(waiting.load() < release_at) {
			asm volatile("pause");
		}

		// Sanity check that the counter is nowhere near wrapping around.
		assert(waiting < (1ul << 60));
	}
};

/**
 * Small xorshift-style pseudorandom generator whose whole state is one
 * `unsigned int`. The same seed always produces the same sequence; a state
 * of 0 stays 0.
 */
class Random {
public:
	/** @param seed  Initial state (converted to unsigned). */
	Random(int seed) : seed(seed) {}

	/** Advance the state by one shift/xor round and return the new state. */
	unsigned int next() {
		unsigned int x = seed;
		x ^= x << 6;
		x ^= x >> 21;
		x ^= x << 7;
		seed = x;
		return x;
	}

private:
	unsigned int seed; // Current generator state.
};

// Clock used to timestamp the start and end of the benchmark run.
using Clock = std::chrono::high_resolution_clock;
// Floating-point duration; with the default period, count() is in seconds.
using duration_t = std::chrono::duration<double>;
// NOTE(review): not referenced in the visible part of this file.
using std::chrono::nanoseconds;

/**
 * Convert a plain number of seconds into a plain number of `Ratio` units.
 *
 * @tparam Ratio    Target period, e.g. std::nano or std::milli.
 * @tparam T        Arithmetic type used for both the input and the result.
 * @param  seconds  Amount of time expressed in seconds.
 * @return          The same amount of time expressed in `Ratio` units.
 */
template<typename Ratio, typename T>
T duration_cast(T seconds) {
	using Seconds = std::chrono::duration<T>;
	using Target  = std::chrono::duration<T, Ratio>;

	const Seconds as_seconds(seconds);
	const Target converted = std::chrono::duration_cast<Target>(as_seconds);
	return converted.count();
}

void waitfor(double & duration, barrier_t & barrier, std::atomic_bool & done) {


	std::cout << "Starting" << std::endl;
	auto before = Clock::now();
	barrier.wait(0);

	while(true) {
		usleep(100000);
		auto now = Clock::now();
		duration_t durr = now - before;
		if( durr.count() > duration ) {
			done = true;
			break;
		}
		std::cout << "\r" << std::setprecision(4) << durr.count();
		std::cout.flush();
	}

	barrier.wait(0);
	auto after = Clock::now();
	duration_t durr = after - before;
	duration = durr.count();
	std::cout << "\rClosing down" << std::endl;
}

/**
 * One benchmark iteration: build a pseudorandom 64-bit mask and index, ask
 * rand_bit for a bit, and abort the program if the answer is not a set bit
 * of the mask.
 *
 * @param rand  Generator supplying the mask halves and the index; advanced
 *              three times per call.
 */
__attribute__((noinline)) void body(Random & rand) {
	// Draw the two halves in separate statements. As operands of a single
	// `|` the two calls would be evaluated in an unspecified order, making
	// the mask sequence depend on the compiler.
	const uint64_t hi = rand.next();
	const uint64_t lo = rand.next();
	const uint64_t mask = (hi << 32) | lo;
	const unsigned idx = rand.next();

	const unsigned bit = rand_bit(idx, mask);

	// 64-bit one regardless of the width of `unsigned long`; an index of 64
	// or more cannot be shifted and is treated as "no bit picked".
	const uint64_t picked = bit < 64 ? (uint64_t{1} << bit) : 0;

	if(__builtin_expect((picked & mask) == 0, false)) {
		std::cerr << std::hex <<  "Rand " << idx << " from " << mask;
		std::cerr << " gave " << picked << "(" << std::dec << bit << ")" << std::endl;
		std::abort();
	}
}

/**
 * Run the rand_bit benchmark on a worker thread for roughly `duration`
 * seconds and print throughput figures to std::cout.
 *
 * @param duration  Requested run time in seconds. waitfor() overwrites the
 *                  local copy with the elapsed time it measured, and that
 *                  measured value is what the statistics are based on.
 */
void runRandBit(double duration) {

	std::atomic_bool done  = { false };
	barrier_t barrier(2); // Two participants: the worker and this thread.

	// Written only by the worker; read here after join().
	size_t count = 0;
	std::thread thread([&done, &barrier, &count]() {

		Random rand(22);

		barrier.wait(1);

		for(;!done; count++) {
			body(rand);
		}

		barrier.wait(1);
	});

	waitfor(duration, barrier, done);
	thread.join();

	size_t ops = count;
	size_t ops_sec = size_t(double(ops) / duration);
	// Nanoseconds spent over the whole measured run. Converting a fixed 1.0 s
	// here would make "ns/Op" wrong by a factor of `duration`.
	auto dur_nano = duration_cast<std::nano>(duration);

	std::cout << "Duration      : " << duration << "s\n";
	std::cout << "ns/Op         : " << ( dur_nano / ops )<< "\n";
	std::cout << "Ops/sec       : " << ops_sec << "\n";
	std::cout << "Total ops     : " << ops << std::endl;

}

/**
 * Entry point: switch std::cout to the environment's locale, then run the
 * benchmark with a requested duration of 5 seconds.
 */
int main() {
	const std::locale user_locale("");
	std::cout.imbue(user_locale);

	runRandBit(5.0);
	return 0;
}