# Based on crunch1
# updates for run-scenario columns not seen back then 
# result eyeballs okay

import io
import os
import sys
from subprocess import Popen, PIPE

import numpy as np
import pandas as pd

def getSingleResults(infileLocal, *,
    tgtMovement = 'all',
    tgtPolarity = 'all',
    tgtAccessor = 'all',
    tgtInterleave = 0.0 ):
    """Load one benchmark result CSV and return a per-run timings DataFrame.

    infileLocal -- file name under ../benchmarks/list/, relative to this
        script's directory.
    tgtMovement / tgtPolarity / tgtAccessor -- keep only rows whose parsed
        op component equals the given value; 'all' keeps everything.
    tgtInterleave -- keep only rows whose InterleaveFrac equals this value
        (coerced via float()); 'all' keeps everything.

    Besides the parsed Args/Program columns, the returned frame carries
    'OpDurRelFx' (mean op duration relative to the lq-tailq baseline fx in
    the same (NumNodes, op, InterleaveFrac) peer group) and 'OpDurRelIntrl'
    (relative to the same fx at interleave 0.0).
    """

    infile = os.path.dirname(os.path.abspath(__file__)) + '/../benchmarks/list/' + infileLocal

    # Rows from error runs were written with a trailing comma (no result
    # fields).  This used to be filtered with `grep '[^,]$'` run through
    # shell=True, which both required an external grep and was open to shell
    # injection via infileLocal; filter the lines in Python instead.  Like
    # the grep, this also drops empty lines.
    with open(infile) as fh:
        goodLines = [line for line in fh
            if line.rstrip('\n') and not line.rstrip('\n').endswith(',')]
    timings = pd.read_csv(
        io.StringIO(''.join(goodLines)),
        names=['RunMoment', 'RunIdx', 'Args', 'Program', 'expt_ops_completed', 'expt_elapsed_sec', 'mean_op_dur_ns'],
        dtype={'RunMoment':          str,
            'RunIdx':             np.int64,
            'Args':               str,
            'Program':            str,
            'expt_ops_completed': np.int64,
            'expt_elapsed_sec':   np.float64,
            'mean_op_dur_ns':     np.float64},
        parse_dates=['RunMoment']
        )
    # print(timings.head())

    ## parse executable name and args

    timings[['ExperimentDurSec',
        'CheckDonePeriod',
        'NumNodes',
        'ExperimentDurOpCount',
        'Seed',
        'InterleaveFrac']] = timings['Args'].str.strip().str.split(expand=True)
    timings["NumNodes"] = pd.to_numeric(timings["NumNodes"])
    # round so float parse noise can't split one interleave fraction into
    # spurious groupby groups
    timings["InterleaveFrac"] = pd.to_numeric(timings["InterleaveFrac"]).round(3)

    timings[['__ProgramPrefix',
        'fx',
        'op']] = timings['Program'].str.split('--', expand=True)

    timings[['movement',
        'polarity',
        'accessor']] = timings['op'].str.split('-', expand=True)

    ## calculate relative to baselines
    baseline_fx = 'lq-tailq'
    baseline_intrl = 0.0

    # chose calc "FineCrossRun" from labpc:crunch3
    byPeer = timings.groupby(['NumNodes', 'op', 'InterleaveFrac'])
    for [NumNodes, op, intrlFrac], peerGroup in byPeer:
        grpfx = peerGroup.groupby(['fx'])
        if baseline_fx in grpfx.groups:
            baselineRows = grpfx.get_group(baseline_fx)
            baselineDur = meanNoOutlr( baselineRows['mean_op_dur_ns'] )
        else:
            baselineDur = 1.0  # no baseline run in this peer group: leave raw ns
        timings.loc[peerGroup.index, 'BaselineFxOpDurNs'] = baselineDur
    timings['OpDurRelFx'] = timings['mean_op_dur_ns'] / timings['BaselineFxOpDurNs']

    # relative to same fx, no interleave
    byPeer = timings.groupby(['NumNodes', 'op', 'fx'])
    for [NumNodes, op, fx], peerGroup in byPeer:
        grpIntrl = peerGroup.groupby(['InterleaveFrac'])
        # guard a missing 0.0 baseline the same way as the fx loop above,
        # instead of letting get_group raise KeyError
        if baseline_intrl in grpIntrl.groups:
            baselineRows = grpIntrl.get_group(baseline_intrl)
            baselineDur = meanNoOutlr( baselineRows['mean_op_dur_ns'] )
        else:
            baselineDur = 1.0
        timings.loc[peerGroup.index, 'BaselineIntrlOpDurNs'] = baselineDur
    timings['OpDurRelIntrl'] = timings['mean_op_dur_ns'] / timings['BaselineIntrlOpDurNs']

    ## narrow to the requested op components / interleave fraction

    if (tgtMovement != 'all'):
        grp = timings.groupby('movement')
        timings = grp.get_group(tgtMovement)
    if (tgtPolarity != 'all'):
        grp = timings.groupby('polarity')
        timings = grp.get_group(tgtPolarity)
    if (tgtAccessor != 'all'):
        grp = timings.groupby('accessor')
        timings = grp.get_group(tgtAccessor)
    if (tgtInterleave != 'all'):
        timings = timings[ timings['InterleaveFrac'] == float(tgtInterleave) ]

    return timings

def getSummaryMeta(metaFileCore):
    """Load '<metaFileCore>-meta.dat' (tab-separated OpIx/Op pairs, next to
    this script) and split each Op label into movement/polarity/accessor
    columns, mapping the '*' wildcard to 'all'."""
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    metafile = scriptDir + "/" + metaFileCore + '-meta.dat'

    meta = pd.read_csv(metafile, names=['OpIx', 'Op'], delimiter='\t')

    # the Op label joins its three fields with a literal backslash-n
    # sequence (plot-label newlines), hence the doubly-escaped pattern
    opParts = meta['Op'].str.split('\\\\n', expand=True)
    meta[['movement', 'polarity', 'accessor']] = opParts

    meta.replace('*', 'all', inplace=True)
    return meta

def printManySummary(*,
        infileLocal,
        metafileCore,
        fxs,
        sizeQual = (lambda x: x < 150),  # e.g. pass lambda x: x < 8 to narrow
        tgtInterleave = 0.0,
        measure = 'OpDurRelFx') :
    """Print a TSV summary: one row of aggregate stats of `measure` per
    (op from the meta file, fx in `fxs`), restricted to runs passing
    `sizeQual` on NumNodes and to the given interleave fraction."""

    metadata = getSummaryMeta(metafileCore)

    print("# op_num\tfx_num\tfx\tmean\tstdev\tmin\tmax\tcount\tpl95\tpl68\tp50\tph68\tph95")

    # one aggregator per reported percentile; p is bound as a default
    # argument so each lambda keeps its own value
    pctAggs = [lambda s, p=p: s.quantile(p)
               for p in (0.025, 0.16, 0.5, 0.84, 0.975)]

    for op in metadata.itertuples():
        perOp = getSingleResults(infileLocal,
            tgtMovement = op.movement,
            tgtPolarity = op.polarity,
            tgtAccessor = op.accessor,
            tgtInterleave = tgtInterleave )

        perOp = perOp[ perOp['fx'].isin(fxs) ]
        perOp = perOp[ perOp['NumNodes'].apply(sizeQual) ]

        # 1-based fx index in the caller-given order, for plot positioning
        perOp.insert(loc=0, column='fx_num',
            value=perOp['fx'].apply(lambda name: fxs.index(name) + 1))
        perOp.insert(loc=0, column='op_num', value=op.OpIx)

        stats = perOp.groupby(['op_num', 'fx_num', 'fx'])[measure].agg(
            ["mean", "std", "min", "max", "count"] + pctAggs)

        print(stats.to_csv(header=False, index=True, sep='\t'), end='')

def printSingleDetail(infileLocal, *,
    tgtMovement = 'all',
    tgtPolarity = 'all',
    tgtAccessor = 'all',
    tgtInterleave = 0.0,
    measure = 'mean_op_dur_ns' ):
    """For each fx in the filtered results, print a quoted fx header and a
    TSV table of per-NumNodes aggregates of `measure` (mean, std, min, max,
    count, sum, outlier-trimmed mean), separated by blank lines."""

    timings = getSingleResults(infileLocal,
        tgtMovement = tgtMovement,
        tgtPolarity = tgtPolarity,
        tgtAccessor = tgtAccessor,
        tgtInterleave = tgtInterleave)

    for fx, fxRows in timings.groupby('fx'):
        perSize = fxRows.groupby(['NumNodes'])[measure].agg(
            ["mean", "std", "min", "max", "count", "sum"]
        )
        # mean with the single min and max dropped (same trimming idea as
        # meanNoOutlr, computed here column-wise on the aggregates)
        trimmedSum = perSize['sum'] - perSize['min'] - perSize['max']
        perSize['mean_no_outlr'] = trimmedSum / ( perSize['count'] - 2 )

        print('"{header}"'.format(header=fx))
        print(perSize.to_csv(header=False, index=True, sep='\t'))
        print()
        print()

def meanNoOutlr(range):
    """Mean of a numeric pandas Series with the single min and max dropped.

    range -- pandas Series of numeric values (the parameter name shadows the
        builtin; kept unchanged for call compatibility).

    With fewer than 3 values there is nothing meaningful to trim: the
    untrimmed formula divides by zero for 2 values (yielding NaN) and only
    handles 1 value by accidental sign cancellation, so fall back to the
    plain mean in those cases.
    """
    n = range.count()
    if n < 3:
        return range.mean()
    return ( range.sum() - range.min() - range.max() ) / ( n - 2 )
