import os
import sys
import time
import itertools
import matplotlib.pyplot as plt
import matplotlib.ticker as ticks
import math
from scipy import stats as st
import numpy as np
from enum import Enum
from statistics import median

import matplotlib
# Select the PGF backend and route all text through pdflatex with a serif font;
# the plots produced further down are saved as .pgf files.
matplotlib.use("pgf")
matplotlib.rcParams["pgf.texsystem"] = "pdflatex"
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["text.usetex"] = True
matplotlib.rcParams["pgf.rcfonts"] = False
matplotlib.rcParams["font.size"] = 16

# Endless supply of marker styles, one single-character code per plotted series.
marker = itertools.cycle("osDxp^h*v")

def sci_format(x, pos):
    """Return *x* in scientific notation with one decimal and a compact exponent.

    The "+0" of a small positive exponent is dropped, so 150000 becomes
    "1.5e5" rather than "1.5e+05". Negative and two-digit exponents are left
    as formatted. *pos* is accepted only because matplotlib's FuncFormatter
    passes a tick position; it is not used.
    """
    formatted = f"{x:.1e}"
    return formatted.replace('+0', '')

# Command line: RESULTS_FILE [MACHINE_NAME]
if len(sys.argv) < 2:
    # Fail with a usage message instead of a bare IndexError on sys.argv[1].
    sys.exit("usage: " + sys.argv[0] + " RESULTS_FILE [MACHINE_NAME]")

# Optional second argument; it is prepended to the plot file names below.
machineName = sys.argv[2] if len(sys.argv) > 2 else ""

# Read everything inside a context manager so the handle is released even if
# one of the header lines fails to parse.
with open(sys.argv[1], "r") as readfile:
    # first line has num times per experiment
    numTimes = int(readfile.readline())

    # second line has processor args
    procs = [int(val) for val in readfile.readline().split()]

    # 3rd line has number of variants
    names = readfile.readline().split()
    numVariants = len(names)

    # Remaining lines with trailing whitespace removed and blank lines dropped.
    # Materialised as a list because the file is closed once this block ends.
    lines = [text for text in (raw.rstrip() for raw in readfile) if text]

# Parser state for the results file: which benchmark section is currently
# being read. Unset means no section is open.
Bench = Enum(
    "Bench",
    {
        "Unset": 0,
        "Contend": 1,
        "Zero": 2,
        "Barrier": 3,
        "Churn": 4,
        "Daisy_Chain": 5,
        "Hot_Potato": 6,
        "Pub_Sub": 7,
    },
)

# Parse the body of the results file and emit one throughput plot per benchmark.
#
# Layout consumed by the loop below, for every benchmark section:
#   - a header line such as "contend:" (see _BENCH_HEADERS);
#   - for each variant: one line that is skipped (presumably the variant
#     name - confirm against the benchmark output), then numTimes sample
#     lines for every entry of procs;
#   - lines starting with "cores" may appear anywhere and are ignored.
# The last whitespace-separated token of a sample line is the measured value.

# Header line -> (file-name stem, title override, parser state). An empty title
# override means the file-name stem doubles as the plot title.
# NOTE(review): stems containing underscores are used verbatim as titles while
# text.usetex is enabled - confirm they render as intended.
_BENCH_HEADERS = {
    "contend:": ("Channel_Contention", "Channel Contention", Bench.Contend),
    "zero:": ("Zero", "", Bench.Zero),
    "barrier:": ("Barrier", "", Bench.Barrier),
    "churn:": ("Churn", "", Bench.Churn),
    "daisy_chain:": ("Daisy_Chain", "", Bench.Daisy_Chain),
    "hot_potato:": ("Hot_Potato", "", Bench.Hot_Potato),
    "pub_sub:": ("Pub_Sub", "", Bench.Pub_Sub),
}

_MARKER_STYLES = ('o', 's', 'D', 'x', 'p', '^', 'h', '*', 'v')


def _bar_length(delta):
    """Return *delta* as an error-bar length: negative or NaN becomes 0.0.

    The bars are measured from the median while the confidence interval is
    centred on the mean, so a raw difference can be negative; errorbar()
    rejects negative lengths. NaN shows up when the interval is undefined
    (for example when all samples are identical).
    """
    return delta if delta > 0 else 0.0


def _plot_benchmark(plot_title, out_path, cores, series, err_bars, legend_names):
    """Draw one errorbar line per variant and save the figure to *out_path*.

    *series* holds one list of y values per variant and *err_bars* the
    matching [lower lengths, upper lengths] pair; both are indexed like
    *cores*. The figure is closed afterwards so pyplot does not keep it alive.
    """
    fig, ax = plt.subplots(layout='constrained')
    ax.set_title(plot_title + " Benchmark")
    ax.set_ylabel("Throughput (channel operations per second)")
    ax.set_xlabel("Cores")
    ax.yaxis.set_major_formatter(ticks.FuncFormatter(sci_format))

    # Fresh cycle per figure so every plot assigns markers in the same order.
    styles = itertools.cycle(_MARKER_STYLES)
    for values, yerr in zip(series, err_bars):
        ax.errorbar(cores, values, yerr, capsize=2, marker=next(styles))

    ax.set_xticks(cores)
    ax.legend(legend_names)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


# --- parser state ---
nameSet = False           # True once the current variant's leading line was skipped
currBench = Bench.Unset   # Unset: the next line must be a benchmark header
count = 0                 # samples collected so far for the current core count
procCount = 0             # index into procs
currVariant = 0           # index into names
experiment_duration = 10.0  # divisor turning a sample into a per-second rate (presumably seconds per run)
name = ""
title = ""
var_name = ""                      # not used in the visible part of the script
sendData = [0.0] * numVariants     # not used in the visible part of the script
# data[variant][core index] -> median throughput
data = [[0.0] * len(procs) for _ in range(numVariants)]
# bars[variant] -> [lower lengths, upper lengths], each indexed like procs
bars = [[[0.0] * len(procs), [0.0] * len(procs)] for _ in range(numVariants)]
tempData = [0.0] * numTimes

for line in lines:
    if currBench == Bench.Unset:
        header = _BENCH_HEADERS.get(line)
        if header is None:
            print("Expected benchmark name")
            print("Line: " + line)
            # Non-zero status so callers can detect the malformed input.
            sys.exit(1)
        name, title, currBench = header
        continue

    if line.startswith("cores"):
        continue

    if not nameSet:
        nameSet = True
        continue

    tempData[count] = float(line.split()[-1]) / experiment_duration
    count += 1
    if count < numTimes:
        continue

    # All repetitions for this (variant, core count) are in: summarise them.
    currMedian = median(tempData)
    data[currVariant][procCount] = currMedian
    # 95% Student-t confidence interval around the sample mean.
    lower, upper = st.t.interval(0.95, numTimes - 1, loc=np.mean(tempData), scale=st.sem(tempData))
    bars[currVariant][0][procCount] = _bar_length(currMedian - lower)
    bars[currVariant][1][procCount] = _bar_length(upper - currMedian)
    count = 0
    procCount += 1
    if procCount < len(procs):
        continue

    # Every core count of this variant is done: move on to the next variant.
    procCount = 0
    nameSet = False
    currVariant += 1
    if currVariant < numVariants:
        continue

    # Every variant of this benchmark is done: plot it and wait for a new header.
    _plot_benchmark(
        title if title != "" else name,
        "plots/" + machineName + name + ".pgf",
        procs,
        data,
        bars,
        names,
    )

    # reset
    title = ""
    currBench = Bench.Unset
    currVariant = 0