package main

import (
	"fmt"
	"sync"
	"time"
	"runtime"
	"os"
	"strconv"
)

// Benchmark configuration. Processors and BarrierSize can be overridden on
// the command line in main; Tasks is recomputed there from both.
var (
	Processors  int = 1
	Tasks       int = 1
	BarrierSize int = 2
)

// Shared run state.
var (
	// done is set by main to ask every task to stop looping.
	done bool
	// total_operations accumulates each task's barrier() call count; guarded by m.
	total_operations uint64
	m                sync.Mutex
)

// Coordination channels. They are allocated here from the default sizes and
// re-created in main once the final Tasks/BarrierSize values are known.
var (
	taskJoin  = make(chan int, Tasks+1)
	barWait   = make(chan int, 2*BarrierSize)
	entryWait = make(chan int, 2*BarrierSize)
)

// initBarrier seeds entryWait with the entry tickets 0, 1, ..., BarrierSize-1
// (in ascending order) so the first barrier round can begin.
func initBarrier() {
	ticket := 0
	for ticket < BarrierSize {
		entryWait <- ticket
		ticket++
	}
}

// barrier blocks the caller until BarrierSize goroutines have called it,
// using two buffered channels of integer tickets:
//
//   - entryWait holds the entry tickets 0..BarrierSize-1 for the current
//     round (seeded in ascending order by initBarrier); each arrival takes one.
//   - The arrival that draws ticket BarrierSize-1 (the last one, since the
//     tickets are queued in order) pushes the release tickets
//     0..BarrierSize-2 into barWait; every other arrival blocks receiving one.
//   - Channel receives are FIFO, so whoever receives release ticket
//     BarrierSize-2 is the last waiter to be released, and it refills
//     entryWait for the next round. With BarrierSize == 1 the sole arrival
//     refills entryWait itself.
//
// A ticket of -1 is a shutdown sentinel (sent by main): the receiver puts it
// back so other blocked goroutines also see it, then returns immediately
// without completing the round.
func barrier() {
	// Take an entry ticket; entryWait stays empty (so this blocks) until the
	// previous round's last waiter refills it.
	ticket := <-entryWait
	if ( ticket == -1 ) {
		// Shutdown: recirculate the sentinel and leave.
		entryWait <- -1
		return
	}
	if ( ticket == BarrierSize - 1 ) {
		// Last arrival: release the BarrierSize-1 goroutines already waiting.
		for j := 0; j < BarrierSize - 1; j++ {
			barWait <- j
		}
	} else {
		// Wait to be released; the received value replaces the entry ticket
		// and decides below whether this goroutine refills entryWait.
		ticket = <- barWait
		if ( ticket == -1 ) {
			// Shutdown while waiting: recirculate the sentinel and leave.
			barWait <- -1
			return
		}
	}

	// last one out
	// The receiver of release ticket BarrierSize-2 (or the only participant
	// when BarrierSize == 1) reopens entry for the next round.
	if ( BarrierSize == 1 || ticket == BarrierSize - 2 ) {
		for j := 0; j < BarrierSize; j++ {
			entryWait <- j
		}
	}
}

// task calls barrier() repeatedly until done is observed, then adds the
// number of barrier() calls it made to total_operations (under m) and
// reports completion on taskJoin.
//
// NOTE(review): done is a plain bool written by main with no
// synchronization, which is a data race under the Go memory model
// (`go run -race` flags it). Consider an atomic flag shared with main.
func task() {
	var passes uint64
	for !done {
		barrier()
		passes++
	}

	m.Lock()
	total_operations += passes
	m.Unlock()

	taskJoin <- 0
}

// usage prints the command-line synopsis, showing the current values of
// Processors and BarrierSize as the defaults, and exits with status 1.
func usage() {
	const synopsis = "Usage: %v " +
		"[ processors (> 0) | 'd' (default %v) ] " +
		"[ BarrierSize (> 0) | 'd' (default %v) ]\n"
	fmt.Printf(synopsis, os.Args[0], Processors, BarrierSize)
	os.Exit(1)
}

func main() {
	switch len( os.Args ) {
		case 3:
			if os.Args[2] != "d" {							// default ?
				BarrierSize, _ = strconv.Atoi( os.Args[2] )
				if BarrierSize < 1 { usage(); }
			} // if
		fallthrough
		case 2:
			if os.Args[1] != "d" {							// default ?
				Processors, _ = strconv.Atoi( os.Args[1] )
				if Processors < 1 { usage(); }
			} // if
		case 1:											// use defaults
		default:
		usage();
	} // switch
	runtime.GOMAXPROCS( Processors );
	Tasks = Processors

	if ( Tasks < BarrierSize ) {
        Tasks = BarrierSize
	}

	// fmt.Println("Processors: ",Processors," Channels: ",Channels," ProdsPerChan: ",ProdsPerChan," ConsPerChan: ",ConsPerChan," Channel Size: ",ChannelSize)
	taskJoin = make(chan int, Tasks + 1)
	barWait = make(chan int, 2 * BarrierSize)
	entryWait = make(chan int, 2 * BarrierSize)
	initBarrier()

	for j := 0; j < Tasks; j++ {
		go task()
	}
		
	// wait 10 seconds
	time.Sleep(time.Second * 10)
	// fmt.Println("prod done\n")
	done = true

	for j := 0; j < BarrierSize; j++ {
		barWait <- -1
		entryWait <- -1
	}

	for j := 0; j < Tasks; j++ {
		<-taskJoin
	}

    fmt.Println(total_operations)
}