#pragma once #include "utils.hpp" class snzm_t { class node; public: const unsigned depth; const unsigned mask; const int root; std::unique_ptr nodes; snzm_t(unsigned numLists); void arrive(int idx) { int i = idx & mask; nodes[i].arrive( idx >> depth); } void depart(int idx) { int i = idx & mask; nodes[i].depart( idx >> depth ); } bool query() const { return nodes[root].query(); } size_t masks( unsigned node ) { /* paranoid */ assert( (node & mask) == node ); return nodes[node].mask; } private: class __attribute__((aligned(64))) node { friend class snzm_t; private: union val_t { static constexpr char Half = -1; uint64_t _all; struct __attribute__((packed)) { char cnt; uint64_t ver:56; }; bool cas(val_t & exp, char _cnt, uint64_t _ver) volatile { val_t t; t.ver = _ver; t.cnt = _cnt; /* paranoid */ assert(t._all == ((_ver << 8) | ((unsigned char)_cnt))); return __atomic_compare_exchange_n(&this->_all, &exp._all, t._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); } bool cas(val_t & exp, const val_t & tar) volatile { return __atomic_compare_exchange_n(&this->_all, &exp._all, tar._all, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); } val_t() : _all(0) {} val_t(const volatile val_t & o) : _all(o._all) {} }; //-------------------------------------------------- // Hierarchical node void arrive_h() { int undoArr = 0; bool success = false; while(!success) { auto x{ value }; /* paranoid */ assert(x.cnt <= 120); if( x.cnt >= 1 ) { if( value.cas(x, x.cnt + 1, x.ver ) ) { success = true; } } /* paranoid */ assert(x.cnt <= 120); if( x.cnt == 0 ) { if( value.cas(x, val_t::Half, x.ver + 1) ) { success = true; x.cnt = val_t::Half; x.ver = x.ver + 1; } } /* paranoid */ assert(x.cnt <= 120); if( x.cnt == val_t::Half ) { /* paranoid */ assert(parent); parent->arrive(); if( !value.cas(x, 1, x.ver) ) { undoArr = undoArr + 1; } } } for(int i = 0; i < undoArr; i++) { /* paranoid */ assert(parent); parent->depart(); } } void depart_h() { while(true) { auto x = (const val_t)value; /* paranoid */ assertf(x.cnt >= 1, "%d", x.cnt); if( value.cas( x, x.cnt - 1, x.ver ) ) { if( x.cnt == 1 ) { /* paranoid */ assert(parent); parent->depart(); } return; } } } //-------------------------------------------------- // Root node void arrive_r() { __atomic_fetch_add(&value._all, 1, __ATOMIC_SEQ_CST); } void depart_r() { __atomic_fetch_sub(&value._all, 1, __ATOMIC_SEQ_CST); } //-------------------------------------------------- // Interface node void arrive() { /* paranoid */ assert(!is_leaf); if(is_root()) arrive_r(); else arrive_h(); } void depart() { /* paranoid */ assert(!is_leaf); if(is_root()) depart_r(); else depart_h(); } private: volatile val_t value; volatile size_t mask = 0; class node * parent = nullptr; bool is_leaf = false; bool is_root() { return parent == nullptr; } public: void arrive( int bit ) { /* paranoid */ assert( is_leaf ); /* paranoid */ assert( (mask & ( 1 << bit )) == 0 ); __atomic_fetch_add( &mask, 1 << bit, __ATOMIC_RELAXED ); arrive_h(); } void depart( int bit ) { /* paranoid */ assert( is_leaf ); /* paranoid */ assert( (mask & ( 1 << bit )) != 0 ); depart_h(); __atomic_fetch_sub( &mask, 1 << bit, __ATOMIC_RELAXED ); } bool query() { /* paranoid */ assert(is_root()); return value._all > 0; } }; }; snzm_t::snzm_t(unsigned numLists) : depth( std::log2( numLists / 8 ) ) , mask( (1 << depth) - 1 ) , root( (1 << (depth + 1)) - 2 ) , nodes(new node[ root + 1 ]()) { int width = 1 << depth; std::cout << "SNZI with Mask: " << depth << "x" << width << "(" << mask << ")" << std::endl; for(int i = 0; i < root; i++) { nodes[i].is_leaf = i < width; nodes[i].parent = &nodes[(i / 2) + width ]; } }