import pandas as pd
import numpy as np
import math
import os
from subprocess import Popen, PIPE
from scipy.stats import gmean

def getDataset( infile ):
    """Load one raw benchmark-results CSV and derive the analysis columns.

    The file has no header row; its columns are RunMoment, RunIdx, Args,
    Program, Width, expt_ops_completed, expt_elapsed_sec, mean_op_dur_ns.

    Derived columns added to the returned DataFrame:
      - ExperimentDurSec, CheckDonePeriod, Length, ExperimentDurOpCount, Seed,
        InterleaveFrac: the six whitespace-separated fields of Args
        (Length and InterleaveFrac are converted to numbers, InterleaveFrac
        rounded to 3 decimals; the others stay as strings)
      - NumNodes: Length * Width
      - __ProgramPrefix, fx, op: the three '--'-separated parts of Program
      - movement, polarity, accessor: the three '-'-separated parts of op
      - SizeZone: 'SM' for 4 <= NumNodes <= 16, 'ML' for 48 <= NumNodes <= 256,
        'none' otherwise

    Raises whatever `open` raises when `infile` cannot be read
    (e.g. FileNotFoundError).
    """
    import io  # local: only this loader needs an in-memory text buffer

    # Error runs leave a row that ends in a comma; drop those rows.
    # The selection matches `grep '[^,]$'` (non-empty line whose last character
    # is not a comma) but is done in-process, so the path is never interpreted
    # by a shell and may contain spaces or shell metacharacters.
    goodLines = []
    with open(infile, encoding='utf-8') as src:
        for line in src:
            body = line.rstrip('\n')
            if body and not body.endswith(','):
                goodLines.append(line)

    timings = pd.read_csv(
        io.StringIO(''.join(goodLines)),
        names=['RunMoment', 'RunIdx', 'Args', 'Program', 'Width',
               'expt_ops_completed', 'expt_elapsed_sec', 'mean_op_dur_ns'],
        dtype={'RunMoment':       str,
            'RunIdx':             np.int64,
            'Args':               str,
            'Program':            str,
            'Width':              np.int64,
            'expt_ops_completed': np.int64,
            'expt_elapsed_sec':   np.float64,
            'mean_op_dur_ns':     np.float64},
        parse_dates=['RunMoment']
        )

    ## parse executable name and args

    timings[['ExperimentDurSec',
        'CheckDonePeriod',
        'Length',
        'ExperimentDurOpCount',
        'Seed',
        'InterleaveFrac']] = timings['Args'].str.strip().str.split(expand=True)
    timings["Length"] = pd.to_numeric(timings["Length"])
    timings["InterleaveFrac"] = pd.to_numeric(timings["InterleaveFrac"]).round(3)

    timings["NumNodes"] = timings["Length"] * timings["Width"]

    timings[['__ProgramPrefix',
        'fx',
        'op']] = timings['Program'].str.split('--', expand=True)

    timings[['movement',
        'polarity',
        'accessor']] = timings['op'].str.split('-', expand=True)

    ## SizeZone as NumNodes t-shirt size
    timings['SizeZone'] = np.select(
        condlist = [
            (4 <= timings['NumNodes']) & (timings['NumNodes'] <= 16),
            (48 <= timings['NumNodes']) & (timings['NumNodes'] <= 256)
        ],
        choicelist = [
            'SM',
            'ML'
        ],
        default = 'none'
    )

    return timings

# `c` = column name
def c( baseName, marginalizeOn ):
    """Return the derived-column name for `baseName` under a marginalization.

    The factor names in `marginalizeOn` are appended in order, each preceded by
    an underscore; an empty `marginalizeOn` yields `baseName` followed by "_".
    """
    factorSuffix = "_".join( marginalizeOn )
    return "_".join( [baseName, factorSuffix] )

# Names of the factor columns that can explain variation in a timing sample.
# annotateBaseline treats each one as either marginalized (listed in its
# marginalizeOn argument) or conditioned on (every other name in this list).
explanations = ['movement', 'polarity', 'accessor',
                'NumNodes',
                'SizeZone', # functional dependency: NumNodes determines SizeZone
                'fx',
                'machine',
                'InterleaveFrac', # unused and always zero
                ]

# helper for avoiding pollution from e.g. alternate cfa list versions
# when a preference-limiting factor is marginalized, make bl value from preferred subset
# but still stamp result everywhere; e.g. even cfa-strip has canon-bl-relative perf
# when conditioning on such factor, peer groups are already small enough to stop such pollution
# use nontrivial marginalizeOn when calculating baseline values, to achieve the above outside-canonical behaviour non-degenerately
# use default full marginalizeOn when removing points from a graph, which leaves only canonical points
def getJustCanon( timings,
                  marginalizeOn = explanations, *,
                    # fxInc omits c++: bl is for comparing intrusives
                    # fxInc omits lq-list: sparse
                    # fxInc omits cfa-fredDisbled: bl is for comparing prod-readies
                  fxInc = ['cfa-cfa', 'lq-tailq', 'upp-upp'],
                  szInc = ['SM', 'ML'],
                  sExcl = [1]
                  ):
    """Return the rows of `timings` that count as canonical baseline sources.

    A filter is applied only for a factor that appears in `marginalizeOn`:
      - 'fx':       keep rows whose fx is in `fxInc`
      - 'SizeZone': keep rows whose SizeZone is in `szInc`
      - 'NumNodes': drop rows whose NumNodes is in `sExcl`
    With the default `marginalizeOn` (all explanations) every filter applies,
    which is the most aggressive setting. When no filter applies, `timings`
    itself is returned.
    """
    # (factor that must be marginalized, mask builder), applied in this order
    keepRules = (
        ('fx',       lambda df: df.fx.isin(fxInc)),
        ('SizeZone', lambda df: df.SizeZone.isin(szInc)),
        ('NumNodes', lambda df: ~ df.NumNodes.isin(sExcl)),
    )
    canon = timings
    for factor, isCanon in keepRules:
        if factor in marginalizeOn:
            canon = canon[ isCanon(canon) ]
    return canon


def annotateBaseline( timings, marginalizeOn ):
    """Add baseline-relative columns for one marginalization to `timings`, in place.

    Column names come from c(base, marginalizeOn):
      - Peers_*:    number of canonical samples in the row's peer group
      - Baseline_*: geometric mean of mean_op_dur_ns over those samples
      - OpDurRel_*: the row's mean_op_dur_ns divided by its Baseline_*
    A peer group is the set of rows that agree on every factor in
    `explanations` that is NOT in `marginalizeOn`. Baseline values are computed
    from the canonical subset (getJustCanon) but stamped on every row; a row
    whose peer group has no canonical sample gets NaN.

    If the Baseline_*/OpDurRel_* columns already exist, nothing is recomputed.
    Returns None.
    """
    c_tgtPeers = c( 'Peers', marginalizeOn )
    c_tgtBl = c("Baseline", marginalizeOn)
    c_tgtRel = c("OpDurRel", marginalizeOn)
    if c_tgtBl in timings.columns or c_tgtRel in timings.columns:
        # already annotated; the two columns are always written together
        assert c_tgtBl in timings.columns and c_tgtRel in timings.columns
        return
    # size handling:
    # two ordinary baselines (sz-nn, nn) and one synthetic baseline (sz)
    # the SizeZone-only baseline has no interpretation wrt a real peer group
    # it isolates the effect of belonging to one SZ or the other
    # while conditioning away the specific-size effects within the SZ
    # notably in zone SM, opDur-v-size usually pitches upward
    # comparing to sz-only baseline gets rid of "they all pitch up," while keeping "SM is faster than ML"
    if 'SizeZone' in marginalizeOn and 'NumNodes' not in marginalizeOn:
        # special case: sz-only synthetic benchmark
        # Derive the helper factor lists in the caller's order. Going through
        # set() made the derived column names depend on string-hash iteration
        # order, so they could differ between runs and from the names that
        # direct calls with the same factors produce.
        margNeither = [ f for f in marginalizeOn if f != 'SizeZone' ]
        margBoth = list( marginalizeOn ) + ['NumNodes']
        margJustNn = margNeither + ['NumNodes']
        annotateBaseline( timings, margNeither )
        annotateBaseline( timings, margBoth )
        annotateBaseline( timings, margJustNn )
        c_neitherRel = c("OpDurRel", margNeither)
        c_bothBl = c("Baseline", margBoth)
        c_justNnBl = c("Baseline", margJustNn)
        # no real peer group exists here, so there is no baseline value (and no Peers column)
        timings[ c_tgtBl ] = np.nan
        timings[ c_tgtRel ] = timings[ c_justNnBl ] / timings[ c_bothBl ] * timings[ c_neitherRel ]
    else: # general case
        # prevent non-canonical samples from polluting baseline values
        # note, depending on the presentation, the polluting points may already be removed from timings entirely
        canonSrc = getJustCanon(timings, marginalizeOn)
        # every explanation not marginalized is conditioned on (kept in `explanations` order)
        conditionOn = [ e for e in explanations if e not in marginalizeOn ]

        if conditionOn:
            stats = canonSrc.groupby(conditionOn)['mean_op_dur_ns'].agg(**{
                c_tgtPeers: 'count',
                c_tgtBl: gmean
            })
            # look up each row's group stats by its conditioned-factor values;
            # .values sidesteps index alignment with `timings`
            group_lookup = timings.set_index(conditionOn).index
            timings[c_tgtPeers] = stats[c_tgtPeers].reindex(group_lookup).values
            timings[c_tgtBl] = stats[c_tgtBl].reindex(group_lookup).values
        else:
            # everything is marginalized: one group holding all canonical rows
            stats = canonSrc.groupby((lambda _: 0))['mean_op_dur_ns'].agg(**{
                c_tgtPeers: 'count',
                c_tgtBl: gmean
            })
            # Extract the single row
            row = stats.iloc[0]
            # Broadcast to all rows
            timings[c_tgtPeers] = row[c_tgtPeers]
            timings[c_tgtBl] = row[c_tgtBl]

        # everywhere := itself / [preferred-subset derived]
        timings[c_tgtRel] = timings['mean_op_dur_ns'] / timings[c_tgtBl]


# longer column name (Peers_%, Baseline_%, OpDurRel_%) gives larger peer group and more (total) variation
def annotateCommonBaselines( timings ):
    """Annotate `timings` (in place) with the commonly used baselines.

    Each foreground marginalization below is applied twice: first on its own,
    then with 'fx' marginalized as well.
    """
    foregrounds = (
        [],                          # all-in baseline (all factors conditioned): only inter-run differences
        ['movement', 'polarity'],
        ['accessor'],
        ['machine'],
        ['SizeZone', 'NumNodes'],    # SizeZone is NOT redundant; conditioned on neither
        ['NumNodes'],                # still conditioned on SizeZone
        ['SizeZone'],                # synthetic: conditioned on NumNodes but not SizeZone
    )
    for background in ( [], ['fx'] ):
        for foreground in foregrounds:
            annotateBaseline( timings, background + foreground )

def getMachineDataset( dsname, machine ):
    """Load results-<machine>-<dsname>.csv and tag every row with `machine`.

    The file is looked up under ../benchmarks/list/ relative to this script's
    directory. Returns the DataFrame from getDataset with a 'machine' column added.
    """
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    infile = scriptDir + '/../benchmarks/list/' + f"results-{machine}-{dsname}.csv"
    tagged = getDataset( infile )
    tagged['machine'] = machine
    return tagged

# Machine names whose result files are loaded by default
# (the default `machines` argument of the functions below).
allMachines = ['swift', 'java']


# "general" fx selections, i.e. excluding the stripped-down experimental CFAs
general_fxs_full = ['cfa-cfa', 'cpp-stlref', 'upp-upp', 'lq-tailq', 'lq-list']
# the full selection minus 'cpp-stlref'
general_fxs_intrusive = ['cfa-cfa', 'upp-upp', 'lq-tailq', 'lq-list']

def getSingleResults(
        dsname = 'general',
        machines = allMachines,
        *,
        fxs = general_fxs_full,
        tgtMovement = 'all',
        tgtPolarity = 'all',
        tgtAccessor = 'all',
        tgtInterleave = 0.0 ):
    """Load dataset `dsname` for every machine and filter it down to the targets.

    Rows are kept when their fx is in `fxs` (rows come out grouped in `fxs`
    order), their movement/polarity/accessor equal the corresponding tgt*
    value, and their InterleaveFrac equals float(tgtInterleave). Passing 'all'
    for a tgt* argument disables that filter.

    Raises KeyError when a requested fx, movement, polarity or accessor does
    not occur in the data (from GroupBy.get_group).
    """
    timings = pd.concat([
        getMachineDataset( dsname, m )
        for m in machines ])

    # select via get_group so that a requested-but-absent fx fails loudly
    byFx = timings.groupby('fx')
    timings = pd.concat([
        byFx.get_group(fx)
        for fx in fxs ])

    # same loud-failure selection for each op component that is pinned
    for column, target in ( ('movement', tgtMovement),
                            ('polarity', tgtPolarity),
                            ('accessor', tgtAccessor) ):
        if target != 'all':
            timings = timings.groupby(column).get_group(target)

    if tgtInterleave != 'all':
        timings = timings[ timings['InterleaveFrac'] == float(tgtInterleave) ]

    return timings

def stripMachine(pyCore):
    """Return `pyCore` without its last '-'-separated component.

    A string containing no '-' yields the empty string.
    """
    head, _dash, _last = pyCore.rpartition('-')
    return head

def getSummaryMeta(metaFileCore):
    """Read <metaFileCore>-meta.dat (next to this script) into a DataFrame.

    The file is tab-separated with columns OpIx and Op. Op is split on the
    two-character sequence backslash-n into movement, polarity and accessor,
    and cells that exactly equal one of the abbreviations below are expanded.
    The expansion is applied to every column of the frame.
    """
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    metadata = pd.read_csv(
        scriptDir + "/" + metaFileCore + '-meta.dat',
        names=['OpIx', 'Op'],
        delimiter='\t'
    )

    opParts = metadata['Op'].str.split('\\\\n', expand=True)
    metadata[['movement', 'polarity', 'accessor']] = opParts

    # whole-cell abbreviation -> full name, applied in this order
    abbreviations = (
        ('*',  'all'),
        ('S',  'stack'),
        ('Q',  'queue'),
        ('iF', 'insfirst'),
        ('iL', 'inslast'),
        ('H',  'allhead'),
        ('Ie', 'inselem'),
        ('Re', 'remelem'),
    )
    for short, full in abbreviations:
        metadata.replace(short, full, inplace=True)
    return metadata

def swiftSweetspot(x):
    """Size predicate: true when 16 < x < 150 (presumably for machine 'swift'; confirm with callers)."""
    return 16 < x < 150
# alternative swift range that was tried: x > 4 and x < 32


def javaSweetspot(x):
    """Size predicate: true when 24 <= x <= 256 (presumably for machine 'java'; confirm with callers)."""
    return 24 <= x <= 256

def printManySummary(*,
        dsname = 'general',
        machines = allMachines,
        metafileCore,
        fxs,
        sizeQual,
        tgtInterleave = 0.0,
        marginalizeOn = ['fx'] ) :
    """Print a tab-separated summary of relative op duration per (op, fx).

    One header comment line is printed, then for every operation listed in the
    <metafileCore>-meta.dat file: the matching samples are loaded, annotated
    with the baseline for `marginalizeOn`, restricted to sizes accepted by
    `sizeQual` (a predicate on NumNodes), and summarized per fx. `fx_num` is the
    1-based position of the fx within `fxs`.
    """
    opTable = getSummaryMeta(metafileCore)

    measure = c( 'OpDurRel', marginalizeOn )

    # one statistic per output column, in the order named by the header line
    summaryStats = [
        "mean", "std", "min", "max", "count",
        lambda x: x.quantile(0.025),
        lambda x: x.quantile(0.16),
        lambda x: x.quantile(0.5),
        lambda x: x.quantile(0.84),
        lambda x: x.quantile(0.975),
    ]

    print("# op_num\tfx_num\tfx\tmean\tstdev\tmin\tmax\tcount\tpl95\tpl68\tp50\tph68\tph95")

    for op in opTable.itertuples():
        samples = getSingleResults(dsname, machines,
            fxs=fxs,
            tgtMovement = op.movement,
            tgtPolarity = op.polarity,
            tgtAccessor = op.accessor,
            tgtInterleave = tgtInterleave )
        # annotate before the size restriction below, so baselines see every size
        annotateBaseline(samples, marginalizeOn)

        samples = samples[ samples['fx'].isin(fxs) ]
        samples = samples[ samples['NumNodes'].apply(sizeQual) ]

        fxPosition = samples['fx'].apply(
            lambda fx: fxs.index(fx) + 1
        )
        samples.insert(loc=0, column='fx_num', value=fxPosition)
        samples.insert(loc=0, column='op_num', value=op.OpIx)

        perOpFx = samples.groupby(['op_num', 'fx_num', 'fx'])[measure].agg(summaryStats)

        print(perOpFx.to_csv(header=False, index=True, sep='\t'), end='')

def printSingleDetail(
        dsname = 'general',
        machines = allMachines,
        *,
        fxs = general_fxs_full,
        tgtMovement = 'all',
        tgtPolarity = 'all',
        tgtAccessor = 'all',
        tgtInterleave = 0.0,
        measureBase = 'mean_op_dur_ns',
        marginalizeOn = explanations ):
    """Print, per fx, a tab-separated table of the measure's statistics by NumNodes.

    `measureBase` selects the measure: 'mean_op_dur_ns' uses the raw column;
    'OpDurRel' annotates baselines for `marginalizeOn` and uses the resulting
    relative column. Any other value raises RuntimeError.

    Each fx block is a quoted fx name, then one row per NumNodes with
    mean, std, min, max, count, sum and the mean with the min and max removed,
    followed by blank lines.
    """
    timings = getSingleResults(dsname, machines,
        fxs = fxs,
        tgtMovement = tgtMovement,
        tgtPolarity = tgtPolarity,
        tgtAccessor = tgtAccessor,
        tgtInterleave = tgtInterleave)

    if measureBase == 'mean_op_dur_ns':
        measure = measureBase
    elif measureBase == 'OpDurRel':
        annotateBaseline(timings, marginalizeOn)
        measure = c( measureBase, marginalizeOn )
    else:
        raise RuntimeError(f"measureBase '{measureBase}' not handled")

    for fx, fxRows in timings.groupby('fx'):
        bySize = fxRows.groupby(['NumNodes'])[measure].agg(
            ["mean", "std", "min", "max", "count", "sum"]
        )
        # mean after discarding one minimum and one maximum sample
        trimmedSum = bySize['sum'] - bySize['min'] - bySize['max']
        bySize['mean_no_outlr'] = trimmedSum / ( bySize['count'] - 2 )

        print(f'"{fx}"')
        print(bySize.to_csv(header=False, index=True, sep='\t'))
        print()
        print()

def aMeanNoOutlr(range):
    """Arithmetic mean of a pandas Series after dropping one minimum and one maximum.

    With fewer than three (non-NaN) samples there is nothing left after
    trimming, so the plain mean is returned (NaN for an empty Series). The
    untrimmed formula would divide by zero for exactly two samples.

    The parameter name shadows the builtin `range`; it is kept for interface
    compatibility.
    """
    n = range.count()
    if n < 3:
        return range.mean()
    return ( range.sum() - range.min() - range.max() ) / ( n - 2 )

def gMeanNoOutlr(range):
    """Geometric mean of a pandas Series after dropping one minimum and one maximum.

    Computed in log space, so large values or long series do not overflow the
    way a raw product does. Values are expected to be positive; NaNs are
    ignored, and a zero or negative value makes the result NaN.

    With fewer than three samples there is nothing left after trimming, so the
    plain geometric mean is returned (NaN for an empty Series). The untrimmed
    formula would raise to the power 1/0 for exactly two samples.

    The parameter name shadows the builtin `range`; it is kept for interface
    compatibility.
    """
    logs = np.log( range.dropna() )
    n = logs.size
    if n < 3:
        return np.exp( logs.mean(skipna=False) )
    # log(prod / min / max) == sum(log) - log(min) - log(max), since log is monotonic
    trimmedLogSum = logs.sum(skipna=False) - logs.min(skipna=False) - logs.max(skipna=False)
    return np.exp( trimmedLogSum / ( n - 2 ) )


def trimPer( df, criteria ):
    """Keep only the rows of `df` that satisfy every entry of `criteria`.

    `criteria` maps a column name to the collection of values allowed in that
    column. An empty mapping returns `df` itself.
    """
    trimmed = df
    for field in criteria:
        allowed = criteria[ field ]
        trimmed = trimmed.loc[ trimmed[ field ].isin(allowed) ]
    return trimmed

# The range from 0.9759 to 1.0247 (which is 1.05 x wide) has 1.0 in its centre.
# This is the bucket with key 0.
# Logs of values in this bucket go from -0.5 to +0.5.
# Rounding a log value to the nearest integer gives the key.
# Exponentiating a key directly gives the centre of its bucket.
# Exponentiating a key less 0.5 gives the bottom of its bucket.
# Gnuplot expects the latter.

# Histogram range and resolution: each bucket is `bucketGrain` times as wide as it is low.
bucketMin = 0.25
bucketMax = 4.0
bucketGrain = 1.05
# Integer bucket keys bracketing [bucketMin, bucketMax]; key 0 is the bucket centred on 1.0.
bktKeyLo = math.floor( math.log(bucketMin, bucketGrain) )
bktKeyHi = math.ceil( math.log(bucketMax, bucketGrain) )


def bktKeyToIx( key ):
    """Convert a bucket key to its 0-based offset from bktKeyLo."""
    return key - bktKeyLo


def bktIxToKey( ix ):
    """Convert a 0-based offset from bktKeyLo back to a bucket key."""
    return ix + bktKeyLo


def bktKeyOfVal( relDur ):
    """Return the key of the bucket containing `relDur`: its log base bucketGrain, rounded."""
    return round( math.log(relDur, bucketGrain) )


def bktIxOfVal( relDur ):
    """Return the 0-based offset (from bktKeyLo) of the bucket containing `relDur`."""
    return bktKeyToIx( bktKeyOfVal( relDur ) )


def botValOfBucketK( key ):
    """Return the lower bound of bucket `key` (half a grain step below its centre)."""
    return bucketGrain ** ( key - 0.5 )


def topValOfBucketBotVal( botVal ):
    """Return the upper bound of the bucket whose lower bound is `botVal`."""
    return bucketGrain * botVal


def botOfBucketOfVal( relDur ):
    """Return the lower bound of the bucket containing `relDur`."""
    return botValOfBucketK( bktKeyOfVal( relDur ) )


# Lower bounds of the buckets with keys bktKeyLo .. bktKeyHi - 1, ascending.
buckets = list( map( botValOfBucketK, range(bktKeyLo, bktKeyHi) ) )

# printSingleDetail
def printHistos(*,
    tgtMovement = 'all',
    tgtPolarity = 'all',
    tgtAccessor = 'all',
    tgtInterleave = 0.0,
    earlyFilter = {}, # exclude from benchmarking
    lateFilter = {}, # exclude from output
    drillOn = ['fx'],
    marginalizeOn = None ):  # None means match drill-on
    """Print one histogram of relative op duration per `drillOn` group.

    Samples are loaded, reduced to canonical rows, and trimmed by `earlyFilter`
    (column -> allowed values; these rows never reach the baseline). Runs that
    share all explanation factors are collapsed to one outlier-trimmed
    geometric mean, baselines for `marginalizeOn` are annotated, and
    `lateFilter` then removes rows from the output only.

    Each group prints as a quoted header (the group key values joined by ', ')
    followed by tab-separated rows `y_lo, y_hi, count`, one per bucket, sorted
    by y_lo. Every bucket in `buckets` appears, with count 0 when empty.
    """
    if marginalizeOn is None:
        marginalizeOn = drillOn

    # watch out for filtering too early here; need everything sticking around until baselines are applied
    # ie, maybe I should get rid of all the tgt parms at the pre-benchmark layers
    timings = getSingleResults(
        tgtMovement = tgtMovement,
        tgtPolarity = tgtPolarity,
        tgtAccessor = tgtAccessor,
        tgtInterleave = tgtInterleave)
    timings = getJustCanon( timings,
                  fxInc = ['cfa-cfa', 'lq-tailq', 'upp-upp', 'lq-list'],
                  szInc = ['SM', 'ML'],
                  sExcl = [1] )

    timings = trimPer( timings, earlyFilter )

    # collapse repeated runs of the same configuration into one duration
    aggregated = timings.groupby(explanations).agg(
        mean_op_dur_ns = ('mean_op_dur_ns', gMeanNoOutlr)
    ).reset_index()
    annotateBaseline(aggregated, marginalizeOn)

    aggregated = trimPer( aggregated, lateFilter )

    # if examining "why CFA slow" need both
    # - getVariousCfa inplace of getJust Canon
    # - do annotate-then-filter because baseline needs to stay cfa-tailq-upp
    # (filter-then-annotate is fine for general cases (where all three canons are included) and good for build time)

    c_measure = c('OpDurRel', marginalizeOn)
    c_measureBkt = 'BUCKET_' + c_measure
    aggregated[ c_measureBkt ] = aggregated[c_measure].apply( botOfBucketOfVal )

    for dkey, dgroup in aggregated.groupby(drillOn):
        counts = dgroup[ c_measureBkt ].value_counts()
        # add every standard bucket that has no samples, keeping any out-of-range
        # buckets that do; one reindex replaces a per-bucket enlargement loop
        counts = counts.reindex( counts.index.union(buckets), fill_value=0 ).sort_index()

        # build the output explicitly rather than relying on how value_counts
        # names its index/values (that naming differs between pandas versions)
        y_lo = counts.index.to_numpy()
        histo = pd.DataFrame({
            'y_lo': y_lo,
            'y_hi': [ topValOfBucketBotVal(lo) for lo in y_lo ],
            'count': counts.to_numpy(),
        })

        # a single-column groupby may yield a scalar key; iterating a string
        # key would split it into characters
        if not isinstance(dkey, tuple):
            dkey = (dkey,)
        header = ', '.join( map( str, dkey ) )
        print(f'"{header}"')
        print(histo.to_csv(header=False, index=False, sep='\t'))
        print()
        print()
