#include <cstdio>
#include <mutex>
#include <thread>
#include <chrono>
#include <stdlib.h>
#include "cppLock.hpp"

#include "../bench.h"

cpp_test_spinlock LOCKS;
cpp_test_spinlock  ** lock_arr;

inline void locks( size_t * arr ) {
    if (num_locks == 2) {
        std::scoped_lock lock( *lock_arr[arr[0]], *lock_arr[arr[1]] );
    } else if (num_locks == 4) {
        std::scoped_lock lock( *lock_arr[arr[0]], *lock_arr[arr[1]], *lock_arr[arr[2]], *lock_arr[arr[3]] );
    } else if (num_locks == 8) {
        std::scoped_lock lock( *lock_arr[arr[0]], *lock_arr[arr[1]], *lock_arr[arr[2]], *lock_arr[arr[3]], *lock_arr[arr[4]], *lock_arr[arr[5]], *lock_arr[arr[6]], *lock_arr[arr[7]] );
    }
}

bool done = false;
uint64_t total = 0;
size_t num_gen = 100; // number of rand orderings per thd
size_t ** rand_arrs;

// generate repeatable orderings for each experiment
void gen_orders() {
    rand_arrs = new size_t *[threads];
    for ( int i = 0; i < threads; i++ )
        rand_arrs[i] = new size_t[ num_locks * num_gen ];

    size_t work_arr[num_locks];

    for ( int i = 0; i < num_locks; i++ )
        work_arr[i] = i;

    size_t curr_idx;
    for ( int i = 0; i < threads; i++ ) {
        state = i;
        curr_idx = 0;
        for ( int j = 0; j < num_gen; j++ ) {
            for ( size_t k = num_locks; k > 0; k-- ) {
                size_t rand_idx = next_int() % k; // choose one of remaining elems in work_arr
                rand_arrs[i][curr_idx] = work_arr[rand_idx];
                curr_idx++;

                // swap chosen elem to end so it isn't picked again
                size_t temp = work_arr[rand_idx];
                work_arr[rand_idx] = work_arr[k - 1];
                work_arr[k - 1] = temp;
            }
        }
        
    }
}

void thread_main( int id ) {
    size_t * my_arr = rand_arrs[id];
    uint64_t count = 0;
    while (true) {
        locks( my_arr + (count % num_gen) * num_locks );
        count++;
        if (done) break;
    }
    __atomic_add_fetch(&total, count, __ATOMIC_SEQ_CST);
}

int main( int argc, char * argv[] ) {
	BENCH_START()
    if ( num_locks == -1 ) { printf("must pass # of locks to program!\n"); exit( EXIT_FAILURE ); }
    
    lock_arr = new cpp_test_spinlock *[ num_locks ];

    if (num_locks >= 2) {
        lock_arr[0] = &l1; lock_arr[1] = &l2;
    }
    if (num_locks >= 4) {
        lock_arr[2] = &l3; lock_arr[3] = &l4;
    }
    if (num_locks == 8) {
        lock_arr[4] = &l5; lock_arr[5] = &l6; lock_arr[6] = &l7; lock_arr[7] = &l8;
    }

    gen_orders();

    std::thread myThreads[threads];
    for (int i = 0; i < threads; i++) {
        myThreads[i] = std::thread(thread_main, i); // move constructed
    }

    std::this_thread::sleep_for (std::chrono::seconds(10));
    done = true;
    
    for (int i = 0; i < threads; i++) {
        myThreads[i].join();
    }

    for ( int i = 0; i < threads; i++ )
        delete[] rand_arrs[i];
    delete[] rand_arrs;
    delete[] lock_arr;

	printf( "%lu\n", total );
}

// Local Variables: //
// tab-width: 4 //
// End: //
