#pragma once

#define SNZI_PACKED

#include "utils.hpp"


/// Packed, hierarchical SNZI (Scalable NonZero Indicator).
///
/// Callers arrive()/depart() on the leaf selected by `idx % mask`; query()
/// reads only the root and reports whether it currently holds a non-zero
/// count. Non-root nodes pack a small signed counter and a 56-bit version
/// into one 64-bit word that is updated with CAS; the root (the node whose
/// parent is nullptr) treats the whole 64-bit word as a plain counter.
class snzip_t {
	class node;
	class node_aligned;
public:
	/// Number of leaves; used as the modulus that maps an index to a leaf.
	const unsigned mask;
	/// Index of the root node inside `nodes`.
	const int root;
	/// Leaf nodes, selected by arrive()/depart().
	std::unique_ptr<snzip_t::node[]> leafs;
	/// Internal nodes (each in its own 128-byte-aligned slot); nodes[root] is the root.
	std::unique_ptr<snzip_t::node_aligned[]> nodes;

	snzip_t(unsigned depth);

	/// Record one arrival on leaf `idx % mask`.
	void arrive(int idx) {
		idx %= mask;
		leafs[idx].arrive();
	}

	/// Record one departure on leaf `idx % mask`; must pair with an earlier
	/// arrive() that selected the same leaf.
	void depart(int idx) {
		idx %= mask;
		leafs[idx].depart();
	}

	/// True iff the root's counter is currently non-zero.
	bool query() const {
		return nodes[root].query();
	}


private:
	/// One SNZI node. A node with a parent runs the hierarchical protocol
	/// (arrive_h/depart_h); the node without a parent is the root and runs a
	/// plain atomic counter (arrive_r/depart_r).
	class __attribute__((aligned(32))) node {
		friend class snzip_t;
	private:

		/// A 64-bit word viewed either as a packed (cnt, ver) pair (non-root
		/// nodes) or as the raw counter `_all` (root).
		union val_t {
			/// Sentinel count: the node is moving 0 -> 1 and the matching
			/// arrival at the parent has not been confirmed yet.
			/// `cnt` must be *signed*: plain `char` is unsigned on some ABIs
			/// (e.g. ARM, PowerPC), where -1 would become 255 and pass the
			/// `cnt >= 1` test in arrive_h().
			static constexpr signed char Half = -1;

			uint64_t _all;
			struct __attribute__((packed)) {
				signed char cnt;
				uint64_t ver:56;
			};

			/// CAS this word from `exp` to the packed pair (_cnt, _ver).
			/// On failure `exp` is overwritten with the current value
			/// (semantics of __atomic_compare_exchange_n); on success it is
			/// left untouched.
			bool cas(val_t & exp, signed char _cnt, uint64_t _ver) volatile {
				val_t t;
				t.ver = _ver;
				t.cnt = _cnt;
				// Layout check: cnt in the low byte, ver in the upper 56 bits.
				/* paranoid */ assert(t._all == ((_ver << 8) | static_cast<unsigned char>(_cnt)));
				return __atomic_compare_exchange_n(&this->_all, &exp._all, t._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
			}

			/// CAS this word from `exp` to `tar`; same failure semantics as above.
			bool cas(val_t & exp, const val_t & tar) volatile {
				return __atomic_compare_exchange_n(&this->_all, &exp._all, tar._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
			}

			val_t() : _all(0) {}
			/// Take a snapshot of a (possibly shared) word. The load is atomic
			/// so the snapshot does not race with concurrent CASes.
			val_t(const volatile val_t & o) : _all(__atomic_load_n(&o._all, __ATOMIC_SEQ_CST)) {}
		};

		//--------------------------------------------------
		// Hierarchical node
		// Arrive at a non-root node. Loops until this thread's arrival is
		// recorded, either by incrementing a positive count or by moving the
		// count 0 -> Half (bumping the version). Whenever Half is observed,
		// the thread arrives at the parent and tries to promote Half -> 1; a
		// failed promotion leaves a surplus parent arrival, which is undone
		// with parent->depart() only after the loop finishes.
		void arrive_h() {
			int undoArr = 0;       // surplus parent arrivals still to undo
			bool success = false;
			while(!success) {
				val_t x{ value };
				/* paranoid */ assert(x.cnt <= 120);
				if( x.cnt >= 1 ) {
					if( value.cas(x, x.cnt + 1, x.ver ) ) {
						success = true;
					}
				}
				// A failed CAS refreshed `x`, so the checks below see the current value.
				/* paranoid */ assert(x.cnt <= 120);
				if( x.cnt == 0 ) {
					if( value.cas(x, val_t::Half, x.ver + 1) ) {
						success = true;
						// A successful CAS does not update `x`; mirror the value just written.
						x.cnt = val_t::Half;
						x.ver = x.ver + 1;
					}
				}
				/* paranoid */ assert(x.cnt <= 120);
				if( x.cnt == val_t::Half ) {
					/* paranoid */ assert(parent);
					// With two surplus parent arrivals pending, recycle one
					// instead of arriving again; this keeps undoArr <= 2.
					if(undoArr == 2) {
						undoArr--;
					} else {
						parent->arrive();
					}
					if( !value.cas(x, 1, x.ver) ) {
						undoArr = undoArr + 1;
					}
				}
			}

			for(int i = 0; i < undoArr; i++) {
				/* paranoid */ assert(parent);
				parent->depart();
			}
		}

		// Depart from a non-root node: decrement the count and, if this
		// departure took it from 1 to 0, depart from the parent as well.
		void depart_h() {
			while(true) {
				val_t x{ value };
				/* paranoid */ assertf(x.cnt >= 1, "%d", x.cnt);
				if( value.cas( x, x.cnt - 1, x.ver ) ) {
					// On success `x` still holds the pre-CAS value.
					if( x.cnt == 1 ) {
						/* paranoid */ assert(parent);
						parent->depart();
					}
					return;
				}
			}
		}

		//--------------------------------------------------
		// Root node: the whole 64-bit word is a plain counter.
		void arrive_r() {
			__atomic_fetch_add(&value._all, 1, __ATOMIC_SEQ_CST);
		}

		void depart_r() {
			__atomic_fetch_sub(&value._all, 1, __ATOMIC_SEQ_CST);
		}

	private:
		volatile val_t value;
		node * parent = nullptr;   // nullptr only for the root

		bool is_root() const {
			return parent == nullptr;
		}

	public:
		void arrive() {
			if(is_root()) arrive_r();
			else arrive_h();
		}

		void depart() {
			if(is_root()) depart_r();
			else depart_h();
		}

		/// Only meaningful on the root: true iff its counter is non-zero.
		bool query() const {
			/* paranoid */ assert(is_root());
			return __atomic_load_n(&value._all, __ATOMIC_SEQ_CST) > 0;
		}
	};

	/// A node placed in its own 128-byte-aligned slot (used for internal
	/// nodes; presumably to avoid false sharing between them).
	class __attribute__((aligned(128))) node_aligned : public node {};
};

// Build the SNZI tree for 2^depth leaves.
//
// Layout (bottom-up, as wired below):
//   - leafs[0 .. width-1]: leaf i has parent nodes[i % (width/2)], so each
//     first-level internal node serves leaves i and i + width/2;
//   - nodes[0 .. root-1]: internal node i has parent nodes[i/2 + width/2];
//   - nodes[root] (root == 2^depth - 2) keeps parent == nullptr, i.e. it is
//     the root.
//
// `inline` because this definition lives in a header: without it, including
// the header from more than one translation unit violates the ODR
// (multiple-definition link error).
//
// Requires 1 <= depth < 31. With depth == 0 there are no internal nodes
// (root == -1) and `i % half_width` would divide by zero.
inline snzip_t::snzip_t(unsigned depth)
	: mask( 1u << depth )
	, root( static_cast<int>(mask) - 2 )   // (2^depth - 1) internal nodes, root is the last
	, leafs(new node[ mask ]())
	, nodes(new node_aligned[ root + 1 ]())
{
	/* paranoid */ assert(depth >= 1);
	const int width = static_cast<int>(mask);
	const int half_width = width / 2;
	// Report the real footprint: leaves plus 128-byte-aligned internal nodes.
	const auto total_bytes = sizeof(snzip_t::node) * mask + sizeof(snzip_t::node_aligned) * (root + 1);
	std::cout << "SNZI: " << depth << "x" << width << "(" << mask - 1 << ") " << total_bytes << " bytes" << std::endl;
	for(int i = 0; i < width; i++) {
		const int idx = i % half_width;
		std::cout << i << " -> " << idx + width << std::endl;
		leafs[i].parent = &nodes[ idx ];
	}

	for(int i = 0; i < root; i++) {
		const int idx = (i / 2) + half_width;
		std::cout << i + width << " -> " << idx + width << std::endl;
		nodes[i].parent = &nodes[ idx ];
	}
}