import os
import sys
import time
import matplotlib.pyplot as plt
import matplotlib.ticker as ticks
import math
from scipy import stats as st
import numpy as np
from enum import Enum
from statistics import median

import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})

if len(sys.argv) < 2:
    sys.exit("usage: " + sys.argv[0] + " RESULTS_FILE [MACHINE_NAME]")

# Optional prefix prepended to every output file name (data/... and plots/...).
machineName = sys.argv[2] if len(sys.argv) > 2 else ""

# Results file layout:
#   line 1: number of timed runs per experiment
#   line 2: whitespace-separated core counts (the x-axis values of the plots)
#   line 3: whitespace-separated variant names (one plotted line / output cell each)
#   rest:   benchmark sections; blank lines are ignored
# Everything is read inside `with` so the input handle is always closed.
with open(sys.argv[1], "r") as readfile:
    numTimes = int(readfile.readline())
    procs = [int(val) for val in readfile.readline().split()]
    names = readfile.readline().split()
    # Remaining non-blank lines, with trailing whitespace/newlines stripped.
    lines = [line for line in (raw.rstrip() for raw in readfile) if line]
numVariants = len(names)

# Benchmark kinds that can appear in the results file. `Unset` is the parser
# state "waiting for the next benchmark name". Value 6 is not used.
Bench = Enum(
    "Bench",
    [
        ("Unset", 0),
        ("Executor", 1),
        ("Matrix", 2),
        ("Repeat", 3),
        ("Balance_One", 4),
        ("Balance_Multi", 5),
        ("Static", 7),
        ("Dynamic", 8),
        ("Mem", 9),
    ],
)

# Benchmark token (alone on its own line in the results file) ->
# (Bench member, display name used in plot titles and output file names).
_BENCH_TOKENS = {
    "executor": (Bench.Executor, "Executor"),
    "matrix": (Bench.Matrix, "Matrix"),
    "repeat": (Bench.Repeat, "Repeat"),
    "balance_one": (Bench.Balance_One, "Balance-One"),
    "balance_multi": (Bench.Balance_Multi, "Balance-Multi"),
    "static": (Bench.Static, "Static"),
    "dynamic": (Bench.Dynamic, "Dynamic"),
    "mem": (Bench.Mem, "ExecutorMemory"),
}

# Benchmarks reported as one median per variant (written under data/) instead
# of being plotted against core count; the value is the output file suffix.
_SEND_FILE_SUFFIX = {
    Bench.Static: "SendStatic",
    Bench.Dynamic: "SendDynamic",
    Bench.Mem: "ExecutorMem",
}

# Plotted benchmarks whose runtime axis uses a log scale.
_LOG_SCALE_BENCHES = (Bench.Executor, Bench.Matrix, Bench.Balance_One, Bench.Repeat)


def _ensure_parent_dir(path):
    """Create the directory that will contain *path* if it does not exist yet."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)


def _error_bar(samples, centre):
    """Return (below, above): distances from *centre* to the ends of the 95%
    Student-t confidence interval of the mean of *samples*.

    The interval is centred on the mean while the plotted point is the median,
    so the median can lie outside it; each distance is clamped at 0 because
    matplotlib's errorbar rejects negative error values.
    """
    lower, upper = st.t.interval(0.95, len(samples) - 1,
                                 loc=np.mean(samples), scale=st.sem(samples))
    return max(centre - lower, 0.0), max(upper - centre, 0.0)


def _write_send_table(bench, values):
    """Write the per-variant medians of *bench* to data/<machineName><suffix>
    as a single line of cells joined by " & " (presumably a LaTeX table row)."""
    fileName = "data/" + machineName + _SEND_FILE_SUFFIX[bench]
    if bench == Bench.Mem:
        # NOTE(review): divides by 1000 before the 'MB' suffix, which assumes
        # the raw value is in KB — confirm against the benchmark's output.
        cells = [str(int(v / 1000)) + 'MB' for v in values]
    else:
        cells = [str(int(v)) + 'ns' for v in values]
    _ensure_parent_dir(fileName)
    # `with` guarantees the file is flushed and closed.
    with open(fileName, 'w') as out:
        out.write(" & ".join(cells))


def _plot_scaling(bench, title):
    """Plot median runtime vs. core count (one line per variant, with error
    bars) from the module-level `data`/`bars` tables and save it as PGF."""
    fig, ax = plt.subplots()
    ax.set_title(title + " Benchmark")
    ax.set_ylabel("Runtime (seconds)")
    ax.set_xlabel("Cores")
    for variant, medians in enumerate(data):
        ax.errorbar(procs, medians, [bars[variant][0], bars[variant][1]],
                    capsize=2, marker='o')
    if bench in _LOG_SCALE_BENCHES:
        ax.set_yscale("log")
        ax.set_ylim(1, None)
        # Plain numeric tick labels instead of the log axis' default 10^k form.
        ax.yaxis.set_major_formatter(ticks.ScalarFormatter())
    else:
        ax.set_ylim(0, None)
    ax.set_xticks(procs)
    ax.legend(names)
    plotName = "plots/" + machineName + title + ".pgf"
    _ensure_parent_dir(plotName)
    fig.savefig(plotName)
    # Close (not merely clear) so figures don't accumulate across benchmarks.
    plt.close(fig)


# Parser state.
nameSet = False          # True once the current variant's name line was consumed
currBench = Bench.Unset  # benchmark currently being parsed
count = 0                # samples collected for the current data point
procCount = 0            # index into procs for the current data point
currVariant = 0          # index into names for the current variant
name = ""                # display name of the current benchmark
var_name = ""            # retained module-level name; not used below
sendData = [0.0 for j in range(numVariants)]
data = [[0.0 for i in range(len(procs))] for j in range(numVariants)]
bars = [[[0.0 for i in range(len(procs))], [0.0 for k in range(len(procs))]] for j in range(numVariants)]
tempData = [0.0 for i in range(numTimes)]

for line in lines:
    if currBench == Bench.Unset:
        # Expecting the token that opens the next benchmark section.
        entry = _BENCH_TOKENS.get(line)
        if entry is None:
            sys.exit("Expected benchmark name, got " + repr(line))
        currBench, name = entry
        continue

    if line.startswith("proc"):
        continue

    if not nameSet:
        # The first line of each variant section is the variant's name; skip it.
        nameSet = True
        continue

    # Each sample line carries its measurement in the last column.
    tempData[count] = float(line.split()[-1])
    count += 1
    if count < numTimes:
        continue
    count = 0
    currMedian = median(tempData)

    if currBench in _SEND_FILE_SUFFIX:
        # One median per variant; no per-core-count sweep.
        sendData[currVariant] = currMedian
        nameSet = False
        currVariant += 1
        if currVariant == numVariants:
            _write_send_table(currBench, sendData)
            currBench = Bench.Unset
            currVariant = 0
        continue

    # Scaling benchmark: one median (+ error bar) per core count per variant.
    data[currVariant][procCount] = currMedian
    below, above = _error_bar(tempData, currMedian)
    bars[currVariant][0][procCount] = below
    bars[currVariant][1][procCount] = above
    procCount += 1
    if procCount == len(procs):
        procCount = 0
        nameSet = False
        currVariant += 1
        if currVariant == numVariants:
            _plot_scaling(currBench, name)
            currBench = Bench.Unset
            currVariant = 0

if currBench != Bench.Unset:
    print("Warning: input ended in the middle of the " + name
          + " benchmark; its output was not written", file=sys.stderr)