#pragma once

#include "utils.hpp"


// Counting tree with per-leaf slot masks (the constructor's banner calls it
// "SNZI with Mask").
//
// Layout, as established by the constructor and the accessors below:
//   - `nodes` holds `root + 1` nodes; indices [0, 1 << depth) are leaves and
//     index `root` is the single root (the only node without a parent).
//   - An index `idx` passed to arrive()/depart() is split in two: the low
//     `depth` bits (`idx & mask`) pick the leaf, the remaining high bits
//     (`idx >> depth`) pick the slot inside that leaf's mask.
//
// arrive()/depart() adjust the leaf's counter and propagate its 0 <-> non-zero
// transitions towards the root; query() only reads the root's counter.
// Counter updates go through `__atomic_*` builtins with seq_cst ordering.
class snzm_t {
	class node;
public:
	// Number of low index bits that select the leaf; the tree has
	// `1 << depth` leaves.
	const unsigned depth;
	// Bit mask `(1 << depth) - 1` extracting the leaf number from an index.
	const unsigned mask;
	// Position of the root node in `nodes` (also the last valid position).
	const int root;
	// Owning array of all `root + 1` tree nodes, leaves first.
	std::unique_ptr<snzm_t::node[]> nodes;

	#if defined(__BMI2__)
		// Byte values 0..7 packed from the least significant byte upwards.
		// Not referenced inside this class; presumably consumed by callers
		// together with masks() -- TODO confirm.
		const uint64_t indexes = 0x0706050403020100;
	#endif

	// Sizes the tree from `numLists` and links every non-root node to its
	// parent; defined out of line.
	snzm_t(unsigned numLists);

	// Marks slot `idx` as occupied: the low `depth` bits choose the leaf, the
	// remaining bits choose the slot inside that leaf.
	void arrive(int idx) {
		int i = idx & mask;
		nodes[i].arrive( idx >> depth);
	}

	// Marks slot `idx` as free again; same index split as arrive().
	void depart(int idx) {
		int i = idx & mask;
		nodes[i].depart( idx >> depth );
	}

	// Returns true while the root's 64-bit counter word is non-zero.
	bool query() const {
		return nodes[root].query();
	}

	// Returns the raw slot mask stored in node number `node`.
	// With BMI2 this is a 64-bit word holding one byte per slot (0xff once
	// arrived, 0x00 once departed); otherwise it is a word with one bit per slot.
	uint64_t masks( unsigned node ) {
		// The index must not have any bit set outside `mask`, i.e. it must be
		// a leaf number.
		/* paranoid */ assert( (node & mask) == node );
		#if defined(__BMI2__)
			return nodes[node].mask_all;
		#else
			return nodes[node].mask;
		#endif
	}

private:
	// One tree node. Every node starts on a 128-byte boundary (presumably to
	// keep nodes on separate cache lines -- TODO confirm).
	class __attribute__((aligned(128))) node {
		friend class snzm_t;
	private:

		// Counter word of a node: a count and a version number that share one
		// 64-bit word (`_all`) so both can be replaced by a single CAS.
		// The assertion in cas() pins the layout: `_all == (ver << 8) | cnt`.
		union val_t {
			// Sentinel count used by arrive_h(): the node has left 0 but its
			// parent has not been updated yet.
			static constexpr char Half = -1;

			uint64_t _all;
			struct __attribute__((packed)) {
				char cnt;
				uint64_t ver:56;
			};

			// Strong compare-and-swap of the whole word against `exp`, installing
			// the pair (_cnt, _ver). Returns true on success; on failure the
			// builtin stores the value it found into `exp`.
			bool cas(val_t & exp, char _cnt, uint64_t _ver) volatile {
				val_t t;
				t.ver = _ver;
				t.cnt = _cnt;
				/* paranoid */ assert(t._all == ((_ver << 8) | ((unsigned char)_cnt)));
				return __atomic_compare_exchange_n(&this->_all, &exp._all, t._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
			}

			// Same as above, with the target already packed into a val_t.
			bool cas(val_t & exp, const val_t & tar) volatile {
				return __atomic_compare_exchange_n(&this->_all, &exp._all, tar._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
			}

			// Starts at count 0, version 0.
			val_t() : _all(0) {}
			// Snapshot constructor: copies the whole word out of a (possibly
			// volatile) counter.
			val_t(const volatile val_t & o) : _all(o._all) {}
		};

		//--------------------------------------------------
		// Hierarchical node
		// Increments this node's count, making sure the parent has been
		// arrived at before the count becomes a positive value.
		//
		// Note that cas() refreshes `x` when it fails, so each later `if` in
		// the same iteration looks at the most recently observed value.
		void arrive_h() {
			// Number of parent->arrive() calls made here whose matching
			// Half -> 1 CAS failed; they are undone after the loop.
			int undoArr = 0;
			bool success = false;
			while(!success) {
				auto x{ value };
				/* paranoid */ assert(x.cnt <= 120);
				// Already positive: just bump the count, version unchanged.
				if( x.cnt >= 1 ) {
					if( value.cas(x, x.cnt + 1, x.ver ) ) {
						success = true;
					}
				}
				/* paranoid */ assert(x.cnt <= 120);
				// Zero: claim the 0 -> Half transition under a new version and
				// fall through to the Half case below with the installed value.
				if( x.cnt == 0 ) {
					if( value.cas(x, val_t::Half, x.ver + 1) ) {
						success = true;
						x.cnt = val_t::Half;
						x.ver = x.ver + 1;
					}
				}
				/* paranoid */ assert(x.cnt <= 120);
				// Half (set by this thread or observed from another one):
				// arrive at the parent, then try to publish count 1. If the
				// CAS fails, this thread's parent arrival is surplus.
				if( x.cnt == val_t::Half ) {
					/* paranoid */ assert(parent);
					parent->arrive();
					if( !value.cas(x, 1, x.ver) ) {
						undoArr = undoArr + 1;
					}
				}
			}

			// Take back every surplus parent arrival recorded above.
			for(int i = 0; i < undoArr; i++) {
				/* paranoid */ assert(parent);
				parent->depart();
			}
		}

		// Decrements this node's count with a CAS retry loop. The call that
		// takes the count from 1 to 0 also departs from the parent.
		void depart_h() {
			while(true) {
				auto x = (const val_t)value;
				/* paranoid */ assertf(x.cnt >= 1, "%d", x.cnt);
				if( value.cas( x, x.cnt - 1, x.ver ) ) {
					if( x.cnt == 1 ) {
						/* paranoid */ assert(parent);
						parent->depart();
					}
					return;
				}
			}
		}

		//--------------------------------------------------
		// Root node
		// The root has no parent to update, so it atomically adds 1 to the
		// whole 64-bit word.
		void arrive_r() {
			__atomic_fetch_add(&value._all, 1, __ATOMIC_SEQ_CST);
		}

		// Atomically subtracts 1 from the root's whole 64-bit word.
		void depart_r() {
			__atomic_fetch_sub(&value._all, 1, __ATOMIC_SEQ_CST);
		}

		//--------------------------------------------------
		// Interface node
		// Arrival coming from a child (hence never called on a leaf):
		// dispatches to the root or the hierarchical implementation.
		void arrive() {
			/* paranoid */ assert(!is_leaf);
			if(is_root()) arrive_r();
			else arrive_h();
		}

		// Departure coming from a child; same dispatch as arrive().
		void depart() {
			/* paranoid */ assert(!is_leaf);
			if(is_root()) depart_r();
			else depart_h();
		}

	private:
		// Count/version word manipulated by the functions above.
		volatile val_t value;
		#if defined(__BMI2__)
			// Slot mask, one byte per slot, also readable as one 64-bit word.
			union __attribute__((packed)) {
				volatile uint8_t mask[8];
				volatile uint64_t mask_all;
			};
		#else
			// Slot mask, one bit per slot.
			volatile size_t mask = 0;
		#endif

		// Parent in the tree; stays nullptr only for the root.
		class node * parent = nullptr;
		// Set by the snzm_t constructor for the first `1 << depth` nodes.
		bool is_leaf = false;

		// A node is the root exactly when it has no parent.
		bool is_root() {
			return parent == nullptr;
		}

	public:
		// Leaf entry point for arrivals: first count the arrival (propagating
		// upwards as needed), then flag slot `bit` in the mask.
		void arrive( int bit ) {
			/* paranoid */ assert( is_leaf );

			arrive_h();
			#if defined(__BMI2__)
				/* paranoid */ assert( bit < 8 );
				// Plain volatile byte store, not an atomic builtin.
				mask[bit] = 0xff;
			#else
				// The bit is asserted to be clear, so the addition sets it.
				/* paranoid */ assert( (mask & ( 1 << bit )) == 0 );
				__atomic_fetch_add( &mask, 1 << bit, __ATOMIC_RELAXED );
			#endif

		}

		// Leaf entry point for departures: mirror image of arrive(int), i.e.
		// the slot is cleared in the mask before the count is decremented.
		void depart( int bit ) {
			/* paranoid */ assert( is_leaf );

			#if defined(__BMI2__)
				/* paranoid */ assert( bit < 8 );
				mask[bit] = 0x00;
			#else
				// The bit is asserted to be set, so the subtraction clears it.
				/* paranoid */ assert( (mask & ( 1 << bit )) != 0 );
				__atomic_fetch_sub( &mask, 1 << bit, __ATOMIC_RELAXED );
			#endif
			depart_h();
		}

		// Root only: true while the counter word is non-zero.
		bool query() {
			/* paranoid */ assert(is_root());
			return value._all > 0;
		}
	};
};

// Integer floor(log2(n)) for n >= 1; yields 0 for n == 0.
// Used instead of std::log2 so that a zero argument cannot produce -inf,
// whose conversion to an unsigned integer is undefined.
inline unsigned snzm_floor_log2(unsigned n) {
	unsigned bits = 0;
	while(n > 1) {
		n >>= 1;
		++bits;
	}
	return bits;
}

// Builds the node array and wires up the tree.
//
// numLists: total number of slots to track. The tree gets
//           2^floor(log2(numLists / 8)) leaves; values below 8 are treated
//           like 8 (depth 0, a single node).
//
// Node numbering: indices [0, width) are the leaves, the following indices
// are the inner levels, and `root` is the last index. Node i's parent is
// node (i / 2) + width, which maps the last two non-root nodes onto `root`.
// The root itself is skipped by the loop, so it keeps a null parent and
// is_leaf == false.
//
// Declared inline because this definition lives in a header: without it,
// including the header from more than one translation unit produces
// duplicate definitions at link time.
inline snzm_t::snzm_t(unsigned numLists)
	: depth( snzm_floor_log2( numLists / 8 ) )
	, mask( (1 << depth) - 1 )
	, root( (1 << (depth + 1)) - 2 )
	, nodes(new node[ root + 1 ]())
{
	int width = 1 << depth;
	std::cout << "SNZI with Mask: " << depth << "x" << width << "(" << mask << ")" << std::endl;
	for(int i = 0; i < root; i++) {
		nodes[i].is_leaf = i < width;
		nodes[i].parent = &nodes[(i / 2) + width ];
	}
}