import os
import sys
import time
import itertools
import matplotlib.pyplot as plt
import matplotlib.ticker as ticks
import math
from scipy import stats as st
import numpy as np
from enum import Enum
from statistics import median

import matplotlib
# Select the PGF backend so that every saved figure is emitted as a .pgf file.
matplotlib.use("pgf")

# Same settings as a single rcParams.update() call, applied key by key in the
# same order.
matplotlib.rcParams["pgf.texsystem"] = "pdflatex"
matplotlib.rcParams['font.family'] = 'serif'
matplotlib.rcParams['text.usetex'] = True
matplotlib.rcParams['pgf.rcfonts'] = False
matplotlib.rcParams['font.size'] = 16

# Endless supply of marker styles; one is drawn per plotted series.
marker = itertools.cycle(
    ['o', 's', 'D', 'x', 'p', '^', 'h', '*', 'v']
)

# Command line: <results-file> [machine-name]
#   results-file : benchmark output to parse (required)
#   machine-name : optional prefix prepended to every output file name
if len(sys.argv) < 2:
    sys.exit("usage: " + sys.argv[0] + " <results-file> [machine-name]")

machineName = sys.argv[2] if len(sys.argv) > 2 else ""

# The file is read completely inside the `with` block so the handle is closed
# before any plotting starts.
with open(sys.argv[1], "r") as readfile:
    # first line has num times per experiment
    numTimes = int(readfile.readline())

    # second line has processor args
    procs = [int(val) for val in readfile.readline().split()]

    # 3rd line has processor for side_chan bench
    sideChanProcs = [int(val) for val in readfile.readline().split()]

    # 4th line has the variant names; their count is the number of variants
    names = readfile.readline().split()
    numVariants = len(names)

    # Remaining lines with trailing whitespace stripped and blank lines
    # dropped.  Materialised as a list (rather than a lazy generator) because
    # the file is closed when this block ends.
    lines = [stripped for stripped in (raw.rstrip() for raw in readfile) if stripped]

def sci_format(x, pos):
    """Format *x* for an axis tick in compact scientific notation.

    The value is rendered with one decimal place in the mantissa and every
    "+0" substring is then removed, so "1.2e+03" becomes "1.2e3".  Exponents
    without a "+0" (e.g. "e+10" or "e-04") are left as they are.

    *pos* is accepted so the function matches the two-argument callback
    signature used with ``FuncFormatter``; it is not used.
    """
    rendered = format(x, '.1e')
    return rendered.replace('+0', '')

def sci_format_label(x):
    """Format *x* for a bar label in compact scientific notation.

    The value is rendered with two decimal places in the mantissa and every
    "+0" substring is then removed, so "1.23e+03" becomes "1.23e3".
    Exponents without a "+0" (e.g. "e+12" or "e-04") are left as they are.
    """
    rendered = format(x, '.2e')
    return rendered.replace('+0', '')

# Which benchmark section of the input is currently being parsed.
# Members are numbered consecutively from 0, so Unset == 0 and Order == 9.
Bench = Enum(
    "Bench",
    [
        "Unset",      # no section active; the next line must be a section header
        "Contend2",
        "Contend4",
        "Contend8",
        "Spin2",
        "Spin4",
        "Spin8",
        "SideChan",
        "Future",
        "Order",
    ],
    start=0,
)

# --- Parser state -----------------------------------------------------------
currBench = Bench.Unset  # section currently being read; Unset means "expect a header"
nameSet = False          # True once the first non-"cores" line of a group has been skipped
count = procCount = currVariant = 0  # sample index / core-count index / variant index
curr_future = 0          # index into future_names of the operation being filled in
name = title = var_name = ""
experiment_duration = 10.0  # every raw sample is divided by this value

# --- Future benchmark accumulators (one row per entry of future_variants) ----
future_variants = ["CFA", "uC++"]
future_names = ["OR", "AND", "AND-OR", "OR-AND"]
future_data = [[0.0] * len(future_names) for _ in range(2)]
# Each row holds two lists: error below the median, error above the median.
future_bars = [[[0.0] * len(future_names), [0.0] * len(future_names)] for _ in range(2)]

# --- Per-variant accumulators ------------------------------------------------
sendData = [0.0] * numVariants
orderData = [0.0] * numVariants
data = [[0.0] * len(procs) for _ in range(numVariants)]
bars = [[[0.0] * len(procs), [0.0] * len(procs)] for _ in range(numVariants)]
sideData = [[0.0] * len(sideChanProcs) for _ in range(numVariants)]
sideBars = [[[0.0] * len(sideChanProcs), [0.0] * len(sideChanProcs)] for _ in range(numVariants)]

# Scratch buffer holding the samples of the group currently being read.
tempData = [0.0] * numTimes
# Marker styles handed out, in order, to the series of each line plot.
_MARKER_STYLES = ('o', 's', 'D', 'x', 'p', '^', 'h', '*', 'v')

# Section header line -> (output name, plot title or None, benchmark kind).
# A None title leaves the previously assigned title untouched.
_SECTION_HEADERS = {
    "contend2:": ("Contend_2", "2 Clause Contend", Bench.Contend2),
    "contend4:": ("Contend_4", "4 Clause Contend", Bench.Contend4),
    "contend8:": ("Contend_8", "8 Clause Contend", Bench.Contend8),
    "spin2:": ("Spin_2", "2 Clause Spin", Bench.Spin2),
    "spin4:": ("Spin_4", "4 Clause Spin", Bench.Spin4),
    "spin8:": ("Spin_8", "8 Clause Spin", Bench.Spin8),
    "sidechan:": ("Sidechan", None, Bench.SideChan),
    "order:": ("order", None, Bench.Order),
}


def _median_and_error(samples):
    """Return ``(median, below, above)`` for one group of samples.

    ``below`` and ``above`` are the distances from the median to the lower
    and upper end of the 95% Student-t interval built from the sample mean
    and its standard error.
    """
    # NOTE(review): the interval is centred on the mean while the offsets are
    # measured from the median; if the median lies outside the interval one
    # offset is negative, which matplotlib may reject -- confirm intended.
    mid = median(samples)
    lower, upper = st.t.interval(
        0.95, len(samples) - 1, loc=np.mean(samples), scale=st.sem(samples)
    )
    return mid, mid - lower, upper - mid


def _save_figure(fig, file_stem):
    """Write *fig* to plots/<machineName><file_stem>.pgf and release it."""
    path = "plots/" + machineName + file_stem + ".pgf"
    # Create the output directory on demand instead of failing in savefig.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(path)
    # close() (rather than only clf()) lets pyplot drop the figure, so figures
    # do not pile up across benchmark sections.
    plt.close(fig)


def _save_line_plot(plot_title, file_stem, cores, series, errors, **subplot_kw):
    """Plot one error-bar line per variant against *cores* and save it.

    *series* holds one list of medians per variant; *errors* holds, per
    variant, a ``[below, above]`` pair of lists.  *subplot_kw* is forwarded to
    ``plt.subplots``.
    """
    fig, ax = plt.subplots(**subplot_kw)
    plt.title(plot_title + " Benchmark")
    plt.ylabel("Throughput (channel operations per second)")
    plt.xlabel("Cores")
    ax.yaxis.set_major_formatter(ticks.FuncFormatter(sci_format))
    # A fresh cycle per figure so every plot starts from the first marker.
    styles = itertools.cycle(_MARKER_STYLES)
    for values, (below, above) in zip(series, errors):
        plt.errorbar(cores, values, yerr=[below, above], capsize=2, marker=next(styles))
    plt.xticks(cores)
    ax.legend(names)
    _save_figure(fig, file_stem)


def _save_future_plot(plot_title, file_stem):
    """Draw the grouped bar chart for the future benchmark and save it."""
    x = np.arange(len(future_names))  # the label locations
    width = 0.45  # the width of the bars
    fig, ax = plt.subplots(layout='constrained')
    plt.title(plot_title + " Benchmark")
    plt.ylabel("Throughput (statement completions per second)")
    plt.xlabel("Operation")
    ax.yaxis.set_major_formatter(ticks.FuncFormatter(sci_format))
    grouped = zip(future_variants, future_data, future_bars)
    for slot, (label, values, errs) in enumerate(grouped):
        # Bars of successive variants sit side by side: centres at
        # x + 0.5 * width, x + 1.5 * width, ...
        rects = ax.bar(x + width * (slot + 0.5), values, width, label=label, yerr=errs)
        ax.bar_label(rects, padding=3, fmt=sci_format_label)
    plt.xticks(x + width, future_names)
    ax.legend(future_variants, loc='lower right')
    _save_figure(fig, file_stem)


# Main parse loop: a small state machine over the non-blank input lines.
#   * While no section is active, the line must be a section header.
#   * Inside a section, lines starting with "cores" are skipped, and so is the
#     first other line of each group (presumably the variant name -- confirm
#     against the benchmark driver's output format).
#   * Every remaining line contributes one sample: its last whitespace
#     separated field divided by experiment_duration.
#   * After numTimes samples the group is summarised and stored; once a
#     section is complete its plot or table row is written and the state
#     returns to Bench.Unset.
for line in lines:
    if currBench == Bench.Unset:
        if line in _SECTION_HEADERS:
            header = _SECTION_HEADERS[line]
        elif line.startswith("future"):
            header = ("Future", "Future Synchronization", Bench.Future)
        else:
            print("Expected benchmark name", file=sys.stderr)
            print("Line: " + line, file=sys.stderr)
            # Malformed input is a failure: exit with a non-zero status.
            sys.exit(1)
        name, header_title, currBench = header
        if header_title is not None:
            title = header_title
        continue

    if line.startswith("cores"):
        continue

    if not nameSet:
        nameSet = True
        continue

    tempData[count] = float(line.split()[-1]) / experiment_duration
    count += 1
    if count != numTimes:
        continue  # group not complete yet

    # A full group of numTimes samples is now in tempData.
    count = 0

    if currBench == Bench.Order:
        # Only the median is recorded for the order benchmark.
        orderData[currVariant] = median(tempData)
        currVariant += 1
        procCount = 0
        nameSet = False
        if currVariant == numVariants:
            fileName = "data/" + machineName + "Order"
            os.makedirs(os.path.dirname(fileName), exist_ok=True)
            # `with` guarantees the row is flushed and the handle closed.
            with open(fileName, 'w') as f:
                f.write(" & ".join(str(int(value)) for value in orderData))

            # reset
            currBench = Bench.Unset
            currVariant = 0
        continue

    currMedian, errBelow, errAbove = _median_and_error(tempData)

    if currBench == Bench.Future:
        future_data[currVariant][curr_future] = currMedian
        future_bars[currVariant][0][curr_future] = errBelow
        future_bars[currVariant][1][curr_future] = errAbove
        nameSet = False
        currVariant += 1
        if currVariant == len(future_variants):
            # Each future section covers one operation; a new header is
            # expected before the next operation's data.
            curr_future += 1
            currBench = Bench.Unset
            currVariant = 0
            if curr_future == len(future_names):
                _save_future_plot(title, name)
        continue

    # Remaining sections are swept over a list of core counts.
    if currBench == Bench.SideChan:
        coreCounts, medians, errors = sideChanProcs, sideData, sideBars
    else:
        coreCounts, medians, errors = procs, data, bars

    medians[currVariant][procCount] = currMedian
    errors[currVariant][0][procCount] = errBelow
    errors[currVariant][1][procCount] = errAbove
    procCount += 1

    if procCount == len(coreCounts):
        procCount = 0
        nameSet = False
        currVariant += 1

        if currVariant == numVariants:
            if currBench == Bench.SideChan:
                # The side-channel plot is titled with its name and uses the
                # default (non-constrained) layout.
                _save_line_plot(name, name, coreCounts, medians, errors)
            else:
                _save_line_plot(title, name, coreCounts, medians, errors, layout='constrained')

            # reset
            currBench = Bench.Unset
            currVariant = 0