package main

import (
	"fmt"
	"sync"
	"time"
	"runtime"
	"os"
	"strconv"
)

// Benchmark configuration. Processors and ChannelSize can be overridden from
// the command line; ProdsPerChan and ConsPerChan are derived from Processors
// in main before any worker is started.
var (
	Processors   int = 1
	Channels     int = 1
	ProdsPerChan int = 1
	ConsPerChan  int = 1
	ChannelSize  int = 128
)

// Shutdown flags: producers stop once prod_done is set, consumers once
// cons_done is set.
// NOTE(review): these are plain bools written by main and polled by worker
// goroutines with no synchronization, which is a data race under the Go
// memory model (the race detector reports it). Switching to sync/atomic
// would require changing every reader at the same time.
var (
	cons_done bool = false
	prod_done bool = false
)

// Shared results, merged into by each worker under m when it exits.
var (
	total_operations uint64 = 0 // sum of per-worker operation counts
	cons_check       uint64 = 0 // XOR of every value received
	prod_check       uint64 = 0 // XOR of every value sent
	m                sync.Mutex
)

// Join channels: every worker sends a single value on exit so main can wait
// for it. The capacities are evaluated during package initialization from the
// initial values above, not from the worker counts main computes later. That
// is harmless: main receives one value per worker, so surplus senders just
// block until it does.
var (
	prodJoin chan int = make(chan int, ProdsPerChan*Channels+1)
	consJoin chan int = make(chan int, ConsPerChan*Channels+1)
)

// consumer receives values from channel until cons_done is set. It XORs every
// received value into a local checksum and counts each receive made while
// prod_done is still unset. On exit it merges the count into total_operations
// and the checksum into cons_check while holding m, then signals consJoin.
//
// Termination relies on main closing channel after setting cons_done. A receive
// on a closed channel returns 0 immediately, which leaves the XOR unchanged and
// lets the next cons_done check end the loop.
func consumer(channel chan uint64) {
	var received, sum uint64
	for !cons_done {
		v := <-channel
		sum ^= v
		if !prod_done {
			received++
		}
	}

	m.Lock()
	total_operations += received
	cons_check ^= sum
	m.Unlock()

	consJoin <- 0
}

// producer sends the sequence 0, 1, 2, ... on channel until prod_done is set.
// Each value is XORed into a local checksum just before it is sent. On exit it
// adds the number of values sent to total_operations and merges the checksum
// into prod_check while holding m, then signals prodJoin.
func producer(channel chan uint64) {
	var next, sum uint64
	for ; !prod_done; next++ {
		sum ^= next
		channel <- next
	}

	m.Lock()
	total_operations += next
	prod_check ^= sum
	m.Unlock()

	prodJoin <- 0
}

// usage prints the command-line synopsis to standard output and exits with
// status 1. It shows the current values of Processors and ChannelSize as the
// defaults.
func usage() {
	const synopsis = "Usage: %v " +
		"[ processors (> 0) | 'd' (default %v) ] " +
		"[ ChannelSize (> 0) | 'd' (default %v) ]\n"
	fmt.Printf(synopsis, os.Args[0], Processors, ChannelSize)
	os.Exit(1)
}

func main() {
	switch len( os.Args ) {
		case 3:
			if os.Args[2] != "d" {							// default ?
				ChannelSize, _ = strconv.Atoi( os.Args[2] )
					if ChannelSize < 0 { usage(); }
			} // if
		fallthrough
		case 2:
			if os.Args[1] != "d" {							// default ?
				Processors, _ = strconv.Atoi( os.Args[1] )
				if Processors < 1 { usage(); }
			} // if
		case 1:											// use defaults
		default:
		usage();
	} // switch
	runtime.GOMAXPROCS( Processors );
	ProdsPerChan = Processors /2
	ConsPerChan = Processors / 2

	// fmt.Println("Processors: ",Processors," Channels: ",Channels," ProdsPerChan: ",ProdsPerChan," ConsPerChan: ",ConsPerChan," Channel Size: ",ChannelSize)
	
	chans := make( [] chan uint64, Channels )
	for i := range chans {
		chans[i] = make(chan uint64, ChannelSize)
	}
	for i := range chans {
		for j := 0; j < ProdsPerChan; j++ {
			go producer( chans[i] )
		}

		for j := 0; j < ConsPerChan; j++ {
			go consumer( chans[i] )
		}
	}
		

	// wait 10 seconds
	time.Sleep(time.Second * 10)
	// fmt.Println("prod done\n")
	prod_done = true
	for j := 0; j < ProdsPerChan * Channels ; j++ {
		<-prodJoin
	}
	// fmt.Println("cons done\n")
	cons_done = true
	for i := range chans {
		L: for {
			select {
				case k := <-chans[i]:
					cons_check = cons_check ^ k
				default:
					break L
			}
		}
	}
	for i := range chans {
		close(chans[i])
	}

	for j := 0; j < ConsPerChan * Channels; j++{
		<-consJoin
	}
	
	
	if cons_check != prod_check {
		fmt.Println("\nChecksum mismatch: Cons: %d, Prods: %d", cons_check, prod_check)
	}
    fmt.Println(total_operations)
}