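# Read the csv given on the command line, e.g. results-sizing-c.csv
# (an optional second argument names the output directory; default: detail-plots).
# Each op dimension (movement, polarity, accessor)
#   that has several values showing up in the input
#   is extended with an 'all' member.
# The InterleaveFrac dimension is not extended; each of its values is kept separate.
# Each resulting op combination defines an output file in the output directory,
#   named after the input file and the op values, e.g.
#   results-sizing-c-stack-insfirst-allhead-<interleave>.dat
#   results-sizing-c-queue-insfirst-allhead-<interleave>.dat
#   results-sizing-c-all-insfirst-allhead-<interleave>.dat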
# For each output file
#   considering the subset of the input data that qualifies,
#   proceed as in crunch1, i.e. [following steps], putting the output in that file
# Split into "series" groups by fx
# Within each series, group by the remaining classifier, size (NumNodes),
#   so that the repeat runs are pooled together
# output (tab-separated):
# x y-mean y-stdev y-min y-max y-count
# where x is size (NumNodes), y is duration (mean_op_dur_ns)
# in chunks, each headed by its quoted fx and separated by blank lines

import pandas as pd
import numpy as np
import sys
import os
from contextlib import redirect_stdout

# Make the shared helpers in <this script's dir>/../../plots (ListCommon,
# which provides getDataset) importable.  abspath() guards against
# __file__ being a bare filename (dirname '' would turn the path into
# '/../../plots', i.e. '/plots' at the filesystem root).
plotsdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "plots")
sys.path.insert(0, plotsdir)
from ListCommon import *

# Command line: INFILE.csv [OUTDIR]
if len(sys.argv) < 2:
    sys.exit("usage: {} INFILE.csv [OUTDIR]".format(sys.argv[0]))
infile = sys.argv[1]

outdir = sys.argv[2] if len(sys.argv) >= 3 else 'detail-plots'
os.makedirs(outdir, exist_ok=True)

timings = getDataset( infile )

## Inventory the op dimensions.

def _with_all(values):
    """Return *values* with an extra 'all' member when it holds several values.

    With a single value, an 'all' member would select exactly the same rows,
    so the array is returned unchanged.
    """
    if values.size <= 1:
        return values
    return np.append(values, 'all')


movements = _with_all(timings['movement'].unique())
polarities = _with_all(timings['polarity'].unique())
accessors = _with_all(timings['accessor'].unique())
# InterleaveFrac is deliberately left without an 'all' member:
# every fraction gets its own output files.
interleaves = timings['InterleaveFrac'].unique()

# Cartesian product of the four dimensions, one row per output file, with
# columns (movement, polarity, accessor, interleave).  np.stack casts the
# four grids to one common dtype.  meshgrid's default 'xy' indexing makes
# polarity vary slowest and interleave fastest in the resulting row order.
grids = np.meshgrid(movements, polarities, accessors, interleaves)
ops = np.stack(grids, -1).reshape(-1, 4)

# Prefix for every output file: the input's name without its directory or
# extension, so all outputs land directly inside outdir regardless of where
# the input lives or how long its extension is.
instem = os.path.splitext(os.path.basename(infile))[0]

for [movement, polarity, accessor, interleave] in ops:    # output-file grain

    tgtOp = '{}-{}-{}-{}'.format(movement, polarity, accessor, interleave)
    outfile = os.path.join(outdir, '{}-{}.dat'.format(instem, tgtOp))
    print ("=== ", outfile, " ===")

    ## re-shape: keep only the rows matching this op combination ('all' keeps
    ## every value of that dimension).  Boolean masks are used instead of
    ## groupby().get_group(), which raises KeyError when a combination does
    ## not occur in the data; a mask just yields an empty frame, handled below.
    timingsFiltered = timings
    if movement != 'all':
        timingsFiltered = timingsFiltered[timingsFiltered['movement'] == movement]
    if polarity != 'all':
        timingsFiltered = timingsFiltered[timingsFiltered['polarity'] == polarity]
    if accessor != 'all':
        timingsFiltered = timingsFiltered[timingsFiltered['accessor'] == accessor]
    if interleave != 'all':
        # float(): np.stack may have cast this value to a common (e.g. string) dtype.
        timingsFiltered = timingsFiltered[timingsFiltered['InterleaveFrac'] == float(interleave)]

    rows = timingsFiltered.shape[0]
    if rows == 0:
        print("skip")
        continue
    print("got", rows)

    with open(outfile, 'w') as f, redirect_stdout(f):
        # One chunk per fx: a quoted header line, then one tab-separated row per
        # NumNodes with mean, std, min, max and count of mean_op_dur_ns.
        for fx, fgroup in timingsFiltered.groupby('fx'):
            aggregated = fgroup.groupby(['NumNodes'])['mean_op_dur_ns'].agg(
                ["mean", "std", "min", "max", "count"])

            print('"{header}"'.format(header=fx))
            print(aggregated.to_csv(header=False, index=True, sep='\t'))
            # Blank lines separate consecutive fx chunks.
            print()
            print()
