package main

import (
	"fmt"
	"os"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
)

// Benchmark configuration. Processors and ChannelSize may be overridden from
// the command line. Channels must stay at least 4: consumer and producer
// select over exactly chans[0] through chans[3].
var Processors, Channels, Prods, Cons, ChannelSize int = 2, 4, 1, 1, 10

// prodDone and consDone are stop flags (0 = keep running, 1 = stop). They are
// written by the coordinating goroutine while the workers poll them, so every
// access goes through sync/atomic to avoid a data race.
var prodDone, consDone int32

// totalOperations accumulates the per-consumer receive counts; guarded by m.
var totalOperations uint64
var m sync.Mutex

// consumer polls the four channels without ever blocking. A value received
// while prodDone is still 0 is counted; values received afterwards (while the
// channels are being drained or after they are closed) are discarded. Once
// consDone is set, the local count is added to totalOperations and wg is
// marked done.
func consumer(chans []chan uint64, wg *sync.WaitGroup) {
	defer wg.Done()

	var count uint64
	for atomic.LoadInt32(&consDone) == 0 {
		received := false
		select {
		case <-chans[0]:
			received = true
		case <-chans[1]:
			received = true
		case <-chans[2]:
			received = true
		case <-chans[3]:
			// Go cases do not fall through, so every channel must set the
			// flag explicitly for its receives to be counted.
			received = true
		default:
		}
		if received && atomic.LoadInt32(&prodDone) == 0 {
			count++
		}
	}

	m.Lock()
	totalOperations += count
	m.Unlock()
}

// producer offers an increasing sequence number to whichever of the four
// channels has buffer space, never blocking, until prodDone is set. The
// sequence number advances on every iteration, whether or not a send happened.
func producer(chans []chan uint64, wg *sync.WaitGroup) {
	defer wg.Done()

	var count uint64
	for atomic.LoadInt32(&prodDone) == 0 {
		select {
		case chans[0] <- count:
		case chans[1] <- count:
		case chans[2] <- count:
		case chans[3] <- count:
		default:
		}
		count++
	}
}

// usage prints the command-line syntax, including the current values of
// Processors and ChannelSize as defaults, to standard error and exits with
// status 1.
func usage() {
	fmt.Fprintf(os.Stderr, "Usage: %v "+
		"[ processors (> 0) | 'd' (default %v) ] "+
		"[ ChannelSize (> 0) | 'd' (default %v) ]\n",
		os.Args[0], Processors, ChannelSize)
	os.Exit(1)
}

// parseArgs interprets the command-line arguments that follow the program name.
// args[0] (optional) is the processor count and args[1] (optional) is the
// channel buffer size. Either may be "d" to keep the supplied default. Both
// must parse as integers > 0. On success it returns the chosen values and
// true. On any error (bad number, value < 1, too many arguments) it returns
// the unchanged defaults and false, so callers can report the real defaults.
func parseArgs(args []string, procs, chanSize int) (int, int, bool) {
	if len(args) > 2 {
		return procs, chanSize, false
	}
	newProcs, newSize := procs, chanSize
	if len(args) >= 2 && args[1] != "d" {
		n, err := strconv.Atoi(args[1])
		if err != nil || n < 1 {
			return procs, chanSize, false
		}
		newSize = n
	}
	if len(args) >= 1 && args[0] != "d" {
		n, err := strconv.Atoi(args[0])
		if err != nil || n < 1 {
			return procs, chanSize, false
		}
		newProcs = n
	}
	return newProcs, newSize, true
}

// run executes one benchmark round. It starts prods producers and cons
// consumers over Channels channels of capacity chanSize and lets them
// exchange values for d. It then returns how many values the consumers
// received while the producers were running. It resets the package-level
// flags and counter, so rounds must not overlap.
func run(prods, cons, chanSize int, d time.Duration) uint64 {
	atomic.StoreInt32(&prodDone, 0)
	atomic.StoreInt32(&consDone, 0)
	m.Lock()
	totalOperations = 0
	m.Unlock()

	chans := make([]chan uint64, Channels)
	for i := range chans {
		chans[i] = make(chan uint64, chanSize)
	}

	var prodWG, consWG sync.WaitGroup
	prodWG.Add(prods)
	for i := 0; i < prods; i++ {
		go producer(chans, &prodWG)
	}
	consWG.Add(cons)
	for i := 0; i < cons; i++ {
		go consumer(chans, &consWG)
	}

	time.Sleep(d)

	// Stop producers first and wait for them: once they have all returned,
	// no send can be in flight, so closing the channels below cannot panic.
	atomic.StoreInt32(&prodDone, 1)
	prodWG.Wait()

	atomic.StoreInt32(&consDone, 1)
	for _, c := range chans {
		close(c)
	}
	consWG.Wait()

	m.Lock()
	defer m.Unlock()
	return totalOperations
}

// main parses the optional processor count and channel size, splits the
// processors between producers and consumers, runs the benchmark for 10
// seconds and prints the number of consumed values.
func main() {
	procs, size, ok := parseArgs(os.Args[1:], Processors, ChannelSize)
	if !ok {
		usage()
	}
	Processors, ChannelSize = procs, size
	runtime.GOMAXPROCS(Processors)

	// Split the processors evenly, but always run at least one producer and
	// one consumer: integer halving alone gives zero workers when
	// Processors == 1, and the benchmark would measure nothing.
	Prods = Processors / 2
	if Prods < 1 {
		Prods = 1
	}
	Cons = Processors / 2
	if Cons < 1 {
		Cons = 1
	}

	// fmt.Println("Processors:", Processors, "Channels:", Channels, "Prods:", Prods, "Cons:", Cons, "ChannelSize:", ChannelSize)

	fmt.Println(run(Prods, Cons, ChannelSize, 10*time.Second))
}