package main

import (
	"fmt"
	"sync"
	"math/rand"
	"time"
	"runtime"
	"os"
	"strconv"
)

// Benchmark configuration. main may override Processors and Channels from the
// command line, and it recomputes Producers and Consumers from Processors.
var (
	Processors  = 1
	Channels    = 4
	Producers   = 1
	Consumers   = 1
	ChannelSize = 128
)

// Shutdown flags: main sets prod_done, then cons_done; the workers poll them.
// NOTE(review): these are plain bools read and written from different
// goroutines without synchronization, which is a data race under the Go
// memory model (the race detector will report it); atomic.Bool would make the
// hand-off well-defined.
var cons_done, prod_done bool

// Aggregated results. Workers merge into these under m as they exit; main
// also folds leftover buffered values into cons_check after all workers stop.
var (
	total_operations uint64 // operations counted by producers and consumers
	cons_check       uint64 // XOR checksum of values taken off the channels
	prod_check       uint64 // XOR checksum of values put on the channels
)

// m guards the aggregated result counters above.
var m sync.Mutex

// Completion signals: each producer/consumer sends one value when it exits.
// The capacity is evaluated once at package initialization, from the initial
// Producers/Consumers values, and does not follow later reassignments in main.
var (
	prodJoin = make(chan int, Producers+1)
	consJoin = make(chan int, Consumers+1)
)

// getRandArray returns a uniformly random permutation of the channel indices
// 0..Channels-1. Each producer and consumer calls it once to choose the order
// in which it cycles round-robin through the channels.
func getRandArray() []int {
	// rand.Perm yields an unbiased permutation. The previous loop of Channels
	// random pair swaps could not be uniform (e.g. for 4 channels it had 16^4
	// equally likely outcomes spread over 24 permutations), and its
	// rand.Intn(Channels) % Channels modulo was redundant.
	return rand.Perm(Channels)
}

// consumer receives from chans, cycling through them in its own shuffled
// round-robin order, until it observes cons_done at the top of an iteration.
// Every received value is XORed into a local checksum. A value counts as an
// operation only if prod_done has not yet been observed. A receive on an
// empty channel blocks, which is why main pushes wake-up zeros into every
// channel after setting cons_done. On exit the consumer merges its count and
// checksum into the shared totals under m and then signals consJoin.
//
// NOTE(review): cons_done and prod_done are read here without
// synchronization while main writes them; that is a data race.
func consumer(chans []chan uint64) {
	order := getRandArray()
	var received, xor uint64

	for pos := 0; !cons_done; pos = (pos + 1) % Channels {
		v := <-chans[order[pos]]
		xor ^= v
		if !prod_done {
			received++
		}
	}

	m.Lock()
	total_operations += received
	cons_check ^= xor
	m.Unlock()

	consJoin <- 0
}

// producer sends the sequence 0, 1, 2, ... into chans, cycling through them in
// its own shuffled round-robin order, until it observes prod_done at the top
// of an iteration. A send blocks while the target channel's buffer is full.
// Every sent value is XORed into a local checksum. On exit the producer merges
// the number of values sent and its checksum into the shared totals under m,
// then signals prodJoin.
//
// NOTE(review): prod_done is read here without synchronization while main
// writes it; that is a data race.
func producer(chans []chan uint64) {
	order := getRandArray()
	var sent, xor uint64

	for pos := 0; !prod_done; pos = (pos + 1) % Channels {
		chans[order[pos]] <- sent
		xor ^= sent
		sent++
	}

	m.Lock()
	total_operations += sent
	prod_check ^= xor
	m.Unlock()

	prodJoin <- 0
}

// usage prints the command-line synopsis, including the current default
// values, to standard output and exits the process with status 1.
//
// The second positional argument is parsed by main into Channels (the number
// of channels). ChannelSize is not settable from the command line, so the
// synopsis names and reports Channels. The previous text advertised
// "ChannelSize" and its default, which did not match what main parses.
func usage() {
	fmt.Printf("Usage: %v "+
		"[ processors (> 0) | 'd' (default %v) ] "+
		"[ channels (> 0) | 'd' (default %v) ]\n",
		os.Args[0], Processors, Channels)
	os.Exit(1)
}

func main() {
	switch len( os.Args ) {
		case 3:
			if os.Args[2] != "d" {							// default ?
				Channels, _ = strconv.Atoi( os.Args[2] )
					if Channels < 0 { usage(); }
			} // if
		fallthrough
		case 2:
			if os.Args[1] != "d" {							// default ?
				Processors, _ = strconv.Atoi( os.Args[1] )
				if Processors < 1 { usage(); }
			} // if
		case 1:											// use defaults
		default:
		usage();
	} // switch
	runtime.GOMAXPROCS( Processors );
	Producers = Processors /2
	Consumers = Processors /2

	// fmt.Println("Processors: ",Processors," Channels: ",Channels," Prods: ",Producers," Cons: ",Consumers," Channel Size: ",ChannelSize)

	chans := make( [] chan uint64, Channels )
	for i := range chans {
		chans[i] = make(chan uint64, ChannelSize)
	}

	for j := 0; j < Consumers; j++ {
		go consumer( chans )
	}

	for j := 0; j < Producers; j++ {
		go producer( chans )
	}

	// wait 10 seconds
	time.Sleep(time.Second * 10)
	prod_done = true

	for j := 0; j < Producers; j++ {
		<-prodJoin
	}

	cons_done = true

	for i := range chans {
		for j := 0; j < Consumers; j++ {
			select {
				case chans[i] <- 0:
					
				default:
					break
			}
		}
	}
	for j := 0; j < Consumers; j++{
		<-consJoin
	}
	for i := range chans {
		L: for {
			select {
				case k := <-chans[i]:
					cons_check = cons_check ^ k
				default:
					break L
			}
		}
	}
	if cons_check != prod_check {
		fmt.Println("\nChecksum mismatch: Cons: %d, Prods: %d", cons_check, prod_check)
	}

    fmt.Println(total_operations)
}