#!/usr/bin/python3
"""
Python Script to plot values obtained by the rmit.py script
Runs a R.I.P.L.

./plot.py [-f FILE] [-o OUT] [-x FIELD] [-y FIELD]
          [--logx] [--logy] [--MaxY N] [--filter PREFIX]
"""

import argparse
import itertools
import json
import math
import numpy
import os
import re
import statistics
import sys
import time

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter, ScalarFormatter

def fmtDur( duration ):
	"""
	Format a duration given in seconds as "mm:ss.mmm".

	Durations of one hour or more are rendered as "h:mm:ss.mmm" so that the
	hour component is not silently dropped.

	Returns " n/a" when duration is falsy (None or zero).
	"""
	if not duration :
		return " n/a"

	# Work in whole milliseconds: rounding once up front avoids float error
	# truncating e.g. 289.999... ms down to 289 when formatted with %d.
	total_ms = int(round(duration * 1000))
	hours, rem = divmod(total_ms, 3600 * 1000)
	minutes, rem = divmod(rem, 60 * 1000)
	seconds, millis = divmod(rem, 1000)
	if hours:
		return "%d:%02d:%02d.%03d" % (hours, minutes, seconds, millis)
	return "%2d:%02d.%03d" % (minutes, seconds, millis)

class Field:
	"""
	Plotting metadata attached to one field label.

	unit   -- unit string handed to the axis EngFormatter
	min    -- lower axis limit used when the axis is linear
	log    -- truthy to draw the axis in log scale; the value "exact" also
	          places the x ticks on the observed x values
	name   -- optional axis label; the field label is used when this is None
	factor -- multiplier applied to raw values plotted on the y axis
	"""
	def __init__(self, unit, _min, _log, _name=None, _factor=1.0):
		# mandatory descriptors
		self.unit, self.min, self.log = unit, _min, _log
		# optional overrides
		self.name, self.factor = _name, _factor

# Plotting metadata for every field label that can be selected as an axis.
# Positional arguments are (unit, linear-axis minimum, log-scale setting);
# _name overrides the axis label and _factor rescales the raw values.
# NOTE(review): the 1e-6 factors paired with unit 's' suggest latencies are
# recorded in microseconds -- confirm against the producer of the data.
field_names = {
	"ns per ops": Field('ns', 0, False),
	"Number of processors": Field('', 1, "exact"),
	"Ops per procs": Field('Ops', 0, False),
	"Ops per threads": Field('Ops', 0, False),
	"ns per ops/procs": Field('', 0, False, _name="ns $\\times$ (Processor $/$ Total Ops)"),
	"Number of threads": Field('', 1, False),
	"Total Operations(ops)": Field('Ops', 0, False),
	"Ops/sec/procs": Field('Ops', 0, False),
	"Total blocks": Field('Blocks', 0, False),
	"Ops per second": Field('', 0, False),
	"Cycle size (# thrds)": Field('thrd', 1, False),
	"Duration (ms)": Field('ms', 0, False),
	"Target QPS": Field('', 0, False),
	"Actual QPS": Field('', 0, False),
	"Average Read Latency": Field('s', 0, False, _factor=1e-6),
	"Median Read Latency": Field('s', 0, True, _factor=1e-6),
	"Tail Read Latency": Field('s', 0, True, _factor=1e-6),
	"Average Update Latency": Field('s', 0, True, _factor=1e-6),
	"Median Update Latency": Field('s', 0, True, _factor=1e-6),
	"Tail Update Latency": Field('s', 0, True, _factor=1e-6),
	"Update Ratio": Field('%', 0, False),
	"Request Rate": Field('req/s', 0, False),
	"Data Rate": Field('b/s', 0, False, _name="Response Throughput", _factor=1000000),
	"Errors": Field('%', 0, False),
}

def plot(in_data, x, y, options, prefix):
	"""
	Draw a scatter plot of y against x for every series in in_data, together
	with the per-x minimum (dotted), maximum (dashed) and median (solid) lines.

	in_data -- iterable of entries where entry[0] is the series name and
	           entry[2] maps field labels to numeric values
	x, y    -- field labels (keys of field_names) used for the two axes
	options -- parsed command line; reads filter, logx, logy, MaxY and out
	prefix  -- leading text stripped from series names to build legend labels

	The figure is saved to options.out when it is set, otherwise it is shown
	in an interactive window.

	Raises ValueError when no retained entry provides both x and y.
	"""
	colors  = itertools.cycle(['#006cb4','#0aa000','#ff6600','#8510a1','#0095e3','#fd8f00','#e30002','#8f00d6','#4b009a','#ffff00','#69df00','#fb0300','#b13f00'])
	markers = itertools.cycle(['x', '+', '1', '2', '3', '4'])
	xfield  = field_names[x]
	yfield  = field_names[y]
	series  = {} # scatter data for each individual data point
	groups  = {} # data points for x value

	print("Preparing Data")

	for entry in in_data:
		name = entry[0]
		if options.filter and not name.startswith(options.filter):
			continue

		values = entry[2]
		if x not in values or y not in values:
			continue

		# Series are only registered once they own a point, so a series that
		# never reports both fields cannot break the extremum computation.
		xval = values[x]
		yval = values[y] * yfield.factor
		points = series.setdefault(name, {'x':[], 'y':[]})
		points['x'].append(xval)
		points['y'].append(yval)
		groups.setdefault(name, {}).setdefault(xval, []).append(yval)

	if not series:
		raise ValueError("no data points provide both '{}' and '{}'".format(x, y))

	print("Preparing Lines")

	lines = {} # lines from groups with min, max and median
	for name, data in groups.items():
		xkeys = sorted(data)
		lines[name] = {
			'x'  : xkeys,
			'min': [min(data[k]) for k in xkeys],
			'max': [max(data[k]) for k in xkeys],
			'med': [statistics.median(data[k]) for k in xkeys],
		}

	print("Making Plots")

	_fig, ax = plt.subplots()
	for name, data in sorted(series.items()):
		_col = next(colors)
		_mrk = next(markers)
		# Fall back to the full name when stripping the prefix leaves nothing
		# (e.g. a single series), so the legend entry is never blank.
		label = name[len(prefix):] or name
		plt.scatter(data['x'], data['y'], color=_col, label=label, marker=_mrk)
		plt.plot(lines[name]['x'], lines[name]['min'], ':', color=_col)
		plt.plot(lines[name]['x'], lines[name]['max'], '--', color=_col)
		plt.plot(lines[name]['x'], lines[name]['med'], '-', color=_col)

	print("Calculating Extremums")

	mx = max(max(s['x']) for s in series.values())
	my = max(max(s['y']) for s in series.values())

	print("Finishing Plots")

	plt.ylabel(yfield.name if yfield.name else y)
	plt.xlabel(xfield.name if xfield.name else x)
	# Positional flag: the keyword was renamed from 'b' to 'visible' across
	# Matplotlib releases, the positional form works with both.
	plt.grid(True)

	# NOTE(review): the x formatter is installed before set_xscale while the y
	# formatter is installed after set_yscale -- confirm the asymmetry is intended.
	ax.xaxis.set_major_formatter( EngFormatter(unit=xfield.unit) )
	if options.logx:
		ax.set_xscale('log')
	elif xfield.log:
		ax.set_xscale('log')
		if xfield.log == "exact":
			# put one tick on every x value that actually occurs in the data
			xvals = set()
			for s in series.values():
				xvals |= set(s['x'])
			ax.set_xticks(sorted(xvals))
			ax.get_xaxis().set_major_formatter(ScalarFormatter())
			plt.xticks(rotation = 45)
	else:
		plt.xlim(xfield.min, mx + 0.25)

	if options.logy:
		ax.set_yscale('log')
	elif yfield.log:
		ax.set_yscale('log')
	else:
		# leave 20% headroom above the largest value unless --MaxY was given
		plt.ylim(yfield.min, options.MaxY if options.MaxY else my*1.2)

	ax.yaxis.set_major_formatter( EngFormatter(unit=yfield.unit) )

	plt.legend(loc='upper left')

	print("Results Ready")
	start = time.time()
	if options.out:
		plt.savefig(options.out, bbox_inches='tight')
	else:
		plt.show()
	end = time.time()
	print("Took {}".format(fmtDur(end - start)))


if __name__ == "__main__":
	# ================================================================================
	# parse command line arguments
	parser = argparse.ArgumentParser(description='Python Script to draw R.M.I.T. results')
	parser.add_argument('-f', '--file', nargs='?', type=argparse.FileType('r'), default=sys.stdin, help="Input file")
	parser.add_argument('-o', '--out', nargs='?', type=str, default=None, help="Output file")
	parser.add_argument('-y', nargs='?', type=str, default="", help="Which field to use as the Y axis")
	parser.add_argument('-x', nargs='?', type=str, default="", help="Which field to use as the X axis")
	parser.add_argument('--logx', action='store_true', help="if set, makes the x-axis logscale")
	parser.add_argument('--logy', action='store_true', help="if set, makes the y-axis logscale")
	parser.add_argument('--MaxY', nargs='?', type=int, help="maximum value of the y-axis")
	parser.add_argument('--filter', nargs='?', type=str, default="", help="if not empty, only print series that start with specified filter")

	options =  parser.parse_args()

	# ================================================================================
	# load data
	try :
		data = json.load(options.file)
	except (ValueError, OSError) :
		# ValueError covers malformed JSON (json.JSONDecodeError) and undecodable
		# text; OSError covers read failures. Anything else (e.g. Ctrl-C) propagates.
		print('ERROR: could not read input', file=sys.stderr)
		parser.print_help(sys.stderr)
		sys.exit(1)

	# ================================================================================
	# identify the keys

	series = set()
	fields = set()

	for entry in data:
		series.add(entry[0])
		fields.update(entry[2])

	# filter out the series if needed
	if options.filter:
		series = {name for name in series if name.startswith(options.filter)}

	# find the common prefix of the retained series so it can be dropped from the legend labels
	prefix = os.path.commonprefix(list(series))

	# only report what was found when the plot is going to be shown interactively
	if not options.out :
		print(series)
		print("fields: ", ' '.join(fields))

	# axes used when -x / -y are absent or name an unknown field
	wantx = "Number of processors"
	wanty = "ns per ops"

	if options.x:
		if options.x in field_names:
			wantx = options.x
		else:
			print("Could not find X key '{}', defaulting to '{}'".format(options.x, wantx))

	if options.y:
		if options.y in field_names:
			wanty = options.y
		else:
			print("Could not find Y key '{}', defaulting to '{}'".format(options.y, wanty))


	plot(data, wantx, wanty, options, prefix)
