#!/usr/bin/env python3

import os
import sys
import math
import argparse
import statistics as st

# Command-line parsing.
# -r and -m are declared with type=int: without it only the defaults are ints,
# and any value supplied on the command line arrives as a str, which breaks the
# range()/arithmetic that consumes these options.
parser = argparse.ArgumentParser(prog = 'GenConvoyStats',
    description = 'Analyzes handoff matrix output and uses Markov chain modelling to provide upper and lower bounds on the largest long term convoy',
    epilog = '')

parser.add_argument("Filename")
parser.add_argument("-r", "--RunsPerNumThds", type=int, default=5, help="Number of trials per # of threads. Default is 5")
parser.add_argument("-m", "--MaxThreads", type=int, default=32, help="Maximum number of threads. Default is 32")
parser.add_argument("-o", "--OutputFile", default='', help="File to write output to")
parser.add_argument("-v", "--Verbose", help="Verbose output. Will print per run stats alongside aggregates", action='count', default=0)
parser.add_argument("-s", "--Single", help="Indicates that input only contains a single run. Assumes the number of threads is the value passed to -m", action='count', default=0)

args = parser.parse_args()

# ---- Per-run handoff state (cleared by reset() after every analyzed run) ----
# handoff:      matrix rows, one list of ints per thread
# sumPerRow:    sum of each matrix row
# maxCycle:     thread indices of the chosen (approximate) largest cycle
handoff, sumPerRow, maxCycle = [], [], []
# handoffTotal: sum of the whole matrix; maxCycleSum: summed weight of maxCycle
handoffTotal = maxCycleSum = 0

# ---- Per-thread-count results, accumulated across runs (cleared by thdReset()) ----
# Lower bound, upper bound and expected convoy (percent) of each run.
minBound, maxBound, expected = [], [], []

# Input file, consumed sequentially by readInMatrix() / goToLineByStartingChars().
readFile = open(args.Filename, "r")

# Optional copy of the report; stays False when -o is not given so output()
# only prints. The file named by -o is opened here, never the input file:
# opening args.Filename in "w" mode would truncate the data being analyzed.
writeFile = False
if args.OutputFile != '':
    writeFile = open(args.OutputFile, "w")

def reset():
    """Clear the per-run handoff state so the next matrix can be read in.

    Empties the handoff matrix, the per-row sums and the chosen cycle in
    place, and zeroes the running totals. The per-thread-count aggregates
    (minBound, maxBound, expected) are untouched; thdReset() handles those.
    """
    global handoffTotal, maxCycleSum
    for container in (handoff, sumPerRow, maxCycle):
        container.clear()
    handoffTotal = maxCycleSum = 0

def thdReset():
    """Discard the per-run bounds collected for the current thread count.

    The lists are emptied in place, so later appends in analyzeRun() start a
    fresh series for the next thread count.
    """
    for series in (minBound, maxBound, expected):
        series.clear()

def output(string):
    """Report one line of text.

    The line always goes to stdout. When an output file is open (writeFile
    is truthy), the same line is first written there, followed by a newline.
    """
    if not writeFile:
        print(string)
        return
    writeFile.write(string + '\n')
    print(string)

# Reads one handoff matrix for a single run, accumulating row/matrix totals as it goes.
def readInMatrix(currThds):
    """Read a currThds-row handoff matrix from readFile into the global state.

    Each non-blank line is one matrix row; values may be separated by commas
    and/or whitespace. Every row is appended to ``handoff``, its sum to
    ``sumPerRow``, and that sum is added to ``handoffTotal``.

    Blank or whitespace-only lines before a row are skipped. If the file ends
    before ``currThds`` rows have been read, an error is printed and the
    program exits with status 1.
    """
    global handoffTotal
    for _ in range(currThds):
        # Skip blank lines first, then test for EOF. readline() returns '' only
        # at EOF, so a trailing run of blank lines is reported as EOF instead of
        # yielding an empty row (which would later crash max() in analyzeRun()).
        line = readFile.readline()
        while line and not line.strip():
            line = readFile.readline()

        if not line:
            print("Incorrect arguments or file format: error in readInMatrix")
            sys.exit(1)

        # Commas are optional separators; convert each field exactly once.
        row = [int(val) for val in line.replace(',', '').split()]
        rowSum = sum(row)

        # Store the row in global state.
        handoff.append(row)
        sumPerRow.append(rowSum)
        handoffTotal += rowSum

# Positions readFile just after the first line that starts with the given prefix.
def goToLineByStartingChars(string):
    """Advance readFile past the first line that begins with *string*.

    Each line is compared after stripping surrounding whitespace, and blank
    lines never match. After a match, the next readline() returns the line
    that follows the matching one. If EOF is reached first, an error is
    printed and the program exits with status 1.
    """
    while True:
        rawLine = readFile.readline()
        # readline() yields '' only at EOF (an empty line is '\n'), so this
        # check must happen before stripping.
        if rawLine == '':
            print("Incorrect arguments or file format: error in goToLineByStartingChars")
            sys.exit(1)

        candidate = rawLine.strip()
        if candidate and candidate.startswith(string):
            return

# Recursively find the largest cycle that includes a specific node, using DFS.
def findLargestCycle(startNode, currNode, visited, globalVisited, currSum, currCycle):
    """Depth-first search for the heaviest cycle that starts and ends at startNode.

    Currently only referenced from commented-out code in analyzeRun(), which
    uses a cheap approximation instead. This search enumerates every simple
    path from startNode; because zero-weight edges are not skipped (see the
    disabled filter below), every node is treated as adjacent to every other,
    so the work grows factorially with the number of threads.

    Args:
        startNode: node the cycle must begin and end at.
        currNode: node being entered at this recursion level.
        visited: per-node flags for nodes on the current DFS path; set on
            entry to a node and cleared again when backtracking out of it.
        globalVisited: per-node flags for start nodes whose cycles were fully
            explored by earlier top-level calls; paths into them are pruned.
        currSum: sum of handoff[a][b] edge weights along the current path.
        currCycle: nodes on the current path, in order, excluding currNode
            (mutated during the search and restored before returning).

    Side effects:
        When a path returns to startNode with currSum strictly greater than the
        global maxCycleSum, sets maxCycleSum and stores a copy of the path in
        the global maxCycle. Because the comparison is strict, the first of
        several equal-weight cycles found is the one kept.
    """
    global handoff, maxCycle, maxCycleSum

    # A node that was the start of an earlier top-level call has already had
    #   all cycles containing it examined, so there is nothing new to find.
    if globalVisited[currNode]:
        return

    # Re-entering a node already on the current path closes a loop.
    if visited[currNode]:
        # Only loops that come back to startNode count; keep the heaviest one.
        if currNode == startNode and currSum > maxCycleSum:
            maxCycleSum = currSum
            maxCycle = currCycle.copy()
        return

    visited[currNode] = True
    # handoff[currNode][idx] is the weight of the edge currNode -> idx.
    for idx, val in enumerate(handoff[currNode]):
        # Disabled filter: with it commented out, zero-weight edges are still
        # followed, so a "cycle" may include edges with no recorded handoffs.
        # if val == 0:
        #     continue

        currCycle.append(currNode)
        findLargestCycle(startNode, idx, visited, globalVisited, currSum + val, currCycle)
        currCycle.pop()
    # Backtrack so other paths may pass through currNode.
    visited[currNode] = False



def analyzeRun():
    """Compute convoy bounds for the handoff matrix currently held in global state.

    Reads ``handoff``, ``sumPerRow`` and ``handoffTotal`` (filled by
    readInMatrix()). Overwrites ``maxCycle`` / ``maxCycleSum`` with the
    approximate largest cycle. Appends one percentage each to ``minBound``,
    ``maxBound`` and ``expected``. Prints the per-run result when -v or -s
    was given.

    A row that sums to zero (a thread with no recorded handoffs) contributes
    a transition probability of 0 instead of raising ZeroDivisionError. An
    all-zero matrix therefore yields bounds of 0%.
    """
    global maxCycle, maxCycleSum
    currThds = len(handoff)

    # Finding the exact largest cycle is NP-hard and not tractable here, so it
    # is estimated below. The exact version, built on findLargestCycle(), is
    # kept for reference:

    # find largest cycle
    # globalVisited = [False] * currThds
    # for i in range(currThds):
    #     visited = [False] * currThds
    #     findLargestCycle(i, i, visited, globalVisited, 0, [])
    #     globalVisited[i] = True

    # # calculate stats
    # cycleHandoffs = []
    # cycleHandoffs.append(handoff[maxCycle[-1]][maxCycle[0]])
    # sumOfMaxCycleRows = sumPerRow[maxCycle[0]]

    # # expected handoff is MULT P( handoff ) for each handoff in max cycle
    # expectedConvoy = handoff[maxCycle[-1]][maxCycle[0]] / sumPerRow[maxCycle[-1]]
    # for idx, val in enumerate(maxCycle[1:]):
    #     cycleHandoffs.append(handoff[maxCycle[idx]][val])
    #     sumOfMaxCycleRows += sumPerRow[val]
    #     expectedConvoy = expectedConvoy * handoff[maxCycle[idx]][val] / sumPerRow[maxCycle[idx]]

    # # adjust expected bound
    # # if max cycle contains all nodes, sumOfMaxCycleRows / handoffTotal == 1
    # # else this adjusts the expected bound to compensate for the cycle not visiting all nodes
    # # also mult by 100 to turn into percentage from decimal
    # expectedConvoy = expectedConvoy * sumOfMaxCycleRows * 100 / handoffTotal

    ################################
    # Approximation: take the largest handoff in every row and treat the
    # selected edges as one cycle that visits every thread.
    maxCycle = list(range(currThds))
    cycleHandoffs = [max(row) for row in handoff]
    maxCycleSum = sum(cycleHandoffs)

    # Expected convoy (%) = product over rows of P(most likely handoff).
    # A zero-sum row has probability 0 (and rowMax is 0 as well).
    expectedConvoy = 100
    for rowMax, rowSum in zip(cycleHandoffs, sumPerRow):
        expectedConvoy = expectedConvoy * rowMax / rowSum if rowSum else 0

    # The approximate cycle visits every thread, so its rows cover all handoffs.
    sumOfMaxCycleRows = handoffTotal
    ###################################################

    # Upper bound: percentage of all handoffs that can lie on the cycle. The
    # cycle can be completed at most min(cycleHandoffs) times.
    maxFeasibleHandoff = min(cycleHandoffs, default=0)
    if handoffTotal:
        upperBoundConvoy = maxFeasibleHandoff * len(maxCycle) * 100 / handoffTotal
    else:
        upperBoundConvoy = 0

    # Lower bound: product of per-row probabilities if only the feasible
    # number of handoffs followed each cycle edge.
    lowerBoundConvoy = 1
    for val in maxCycle:
        rowSum = sumPerRow[val]
        lowerBoundConvoy = lowerBoundConvoy * maxFeasibleHandoff / rowSum if rowSum else 0

    # Scale by the fraction of all handoffs that fall in the cycle's rows
    # (1 for the approximation) and convert to a percentage.
    if handoffTotal:
        lowerBoundConvoy = lowerBoundConvoy * sumOfMaxCycleRows * 100 / handoffTotal
    else:
        lowerBoundConvoy = 0

    maxBound.append(upperBoundConvoy)
    minBound.append(lowerBoundConvoy)
    expected.append(expectedConvoy)

    if args.Verbose or args.Single:
        output('Convoying bounds: {:.2f}%-{:.2f}%, Expected convoying: {:.2f}%'.format(lowerBoundConvoy,upperBoundConvoy, expectedConvoy))


# Driver. In single mode, analyze one matrix for -m threads. Otherwise, for each
# thread count N = 1..MaxThreads, read RunsPerNumThds matrices and report the
# mean/median bounds across them.
# int() keeps this working even if -r/-m arrive as strings from the command line.
maxThreads = int(args.MaxThreads)
runsPerNumThds = int(args.RunsPerNumThds)

try:
    if args.Single:
        output("N: " + str(maxThreads))
        goToLineByStartingChars(str(maxThreads)+' ')
        readInMatrix(maxThreads)
        analyzeRun()
        reset()
    else:
        for numThds in range(1, maxThreads + 1):
            output("N: " + str(numThds))

            goToLineByStartingChars(str(numThds)+' ')
            for _ in range(runsPerNumThds):
                readInMatrix(numThds)
                analyzeRun()
                reset()

            output('Mean convoying bounds: {:.2f}%-{:.2f}%, Mean expected convoying: {:.2f}%'.format(st.mean(minBound), st.mean(maxBound),st.mean(expected)))
            output('Median convoying bounds: {:.2f}%-{:.2f}%, Median expected convoying: {:.2f}%'.format(st.median(minBound), st.median(maxBound),st.median(expected)))
            output('')
            thdReset()
finally:
    # Close the files even when a helper bails out via sys.exit(), so that
    # report lines already written to the -o file are flushed.
    readFile.close()

    if writeFile:
        writeFile.close()