#!/usr/bin/python3
"""
Python Script to plot values obtained by the rmit.py script
Runs a R.I.P.L.

./plot.py
-t trials
-o option:values
"""

import argparse
import itertools
import json
import math
import numpy
import os
import re
import statistics
import sys

import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter

class Field:
	"""Axis/display metadata for one benchmark result field.

	unit    -- unit suffix handed to EngFormatter for tick labels
	_min    -- lower axis limit used when the axis is not log-scaled
	_log    -- if true, the axis for this field defaults to log scale
	_name   -- optional axis label; the field key is used when None
	_factor -- multiplier applied to the field's values when used on the Y axis
	"""
	def __init__(self, unit, _min, _log, _name=None, _factor=1.0):
		self.unit, self.min, self.log = unit, _min, _log
		self.name, self.factor = _name, _factor

# Scale factor applied to the latency fields' raw values before plotting.
# NOTE(review): presumably the raw latencies are in microseconds (the axis unit is 's') — confirm.
_MICRO = 1e-6

# Known result fields, keyed by the label found in each entry's results dict,
# mapped to how that field should be drawn on an axis.
field_names = {
	"ns per ops"            : Field(unit='ns',     _min=0, _log=False),
	"Number of processors"  : Field(unit='',       _min=1, _log=False),
	"Ops per procs"         : Field(unit='Ops',    _min=0, _log=False),
	"Ops per threads"       : Field(unit='Ops',    _min=0, _log=False),
	"ns per ops/procs"      : Field(unit='',       _min=0, _log=False,
	                                _name="Latency ((ns $/$ Operation) $\\times$ Processor)"),
	"Number of threads"     : Field(unit='',       _min=1, _log=False),
	"Total Operations(ops)" : Field(unit='Ops',    _min=0, _log=False),
	"Ops/sec/procs"         : Field(unit='Ops',    _min=0, _log=False),
	"Total blocks"          : Field(unit='Blocks', _min=0, _log=False),
	"Ops per second"        : Field(unit='',       _min=0, _log=False),
	"Cycle size (# thrds)"  : Field(unit='thrd',   _min=1, _log=False),
	"Duration (ms)"         : Field(unit='ms',     _min=0, _log=False),
	"Target QPS"            : Field(unit='',       _min=0, _log=False),
	"Actual QPS"            : Field(unit='',       _min=0, _log=False),
	# NOTE(review): unlike the other latency fields this one is not log-scaled by default — intentional?
	"Average Read Latency"  : Field(unit='s',      _min=0, _log=False, _factor=_MICRO),
	"Median Read Latency"   : Field(unit='s',      _min=0, _log=True,  _factor=_MICRO),
	"Tail Read Latency"     : Field(unit='s',      _min=0, _log=True,  _factor=_MICRO),
	"Average Update Latency": Field(unit='s',      _min=0, _log=True,  _factor=_MICRO),
	"Median Update Latency" : Field(unit='s',      _min=0, _log=True,  _factor=_MICRO),
	"Tail Update Latency"   : Field(unit='s',      _min=0, _log=True,  _factor=_MICRO),
	"Update Ratio"          : Field(unit='%',      _min=0, _log=False),
	"Request Rate"          : Field(unit='req/s',  _min=0, _log=False),
	# Raw values are multiplied by 10^6 before plotting.
	"Data Rate"             : Field(unit='b/s',    _min=0, _log=False,
	                                _factor=1000000, _name="Response Throughput"),
	"Errors"                : Field(unit='%',      _min=0, _log=False),
}

def plot(in_data, x, y, options, prefix):
	"""Draw each series as a scatter plus min/max (dashed) and median (solid) lines.

	in_data -- entries as loaded from the results JSON; entry[0] is used as
	           the series name and entry[2] as a dict of field name -> value.
	x, y    -- keys of `field_names` to use for the X and Y axes.
	options -- parsed command-line options (reads logx, logy, MaxY, out).
	prefix  -- common prefix stripped from series names for legend labels.

	Saves the figure to options.out when set, otherwise shows it.
	Raises ValueError if no entry contains both the X and Y fields.
	"""
	_, ax = plt.subplots()
	colors = itertools.cycle(['#006cb4','#0aa000','#ff6600','#8510a1','#0095e3','#fd8f00','#e30002','#8f00d6','#4b009a','#ffff00','#69df00','#fb0300','#b13f00'])
	xfield = field_names[x]
	yfield = field_names[y]

	print("Preparing Data")

	series = {} # scatter data for each individual data point
	groups = {} # per series: y values grouped by x value
	for entry in in_data:
		name = entry[0]
		values = entry[2]
		points = series.setdefault(name, {'x':[], 'y':[]})
		by_x = groups.setdefault(name, {})

		if x in values and y in values:
			xval = values[x]
			yval = values[y] * yfield.factor
			points['x'].append(xval)
			points['y'].append(yval)
			by_x.setdefault(xval, []).append(yval)

	# Series whose entries never carried both fields have nothing to draw; keeping
	# them would add empty legend entries and make max() below fail on [].
	series = {name: data for name, data in series.items() if data['x']}
	if not series:
		raise ValueError("no entry contains both '{}' and '{}'".format(x, y))

	print("Preparing Lines")

	lines = {} # per series: min, max and median of y for each distinct x
	for name in series:
		line = { 'x': [], 'min':[], 'max':[], 'med':[] }
		for xkey in sorted(groups[name]):
			ys = groups[name][xkey]
			line['x']  .append(xkey)
			line['min'].append(min(ys))
			line['max'].append(max(ys))
			line['med'].append(statistics.median(ys))
		lines[name] = line

	print("Making Plots")

	for name, data in sorted(series.items()):
		_col = next(colors)
		# With a single series the common prefix is the whole name; fall back to
		# the full name so the legend entry is not empty (and thus dropped).
		label = name[len(prefix):] or name
		plt.scatter(data['x'], data['y'], color=_col, label=label, marker='x')
		plt.plot(lines[name]['x'], lines[name]['min'], '--', color=_col)
		plt.plot(lines[name]['x'], lines[name]['max'], '--', color=_col)
		plt.plot(lines[name]['x'], lines[name]['med'], '-', color=_col)

	print("Calculating Extremums")

	mx = max(max(s['x']) for s in series.values())
	my = max(max(s['y']) for s in series.values())

	print("Finishing Plots")

	plt.ylabel(yfield.name if yfield.name else y)
	plt.xlabel(xfield.name if xfield.name else x)
	# Positional argument: the `b` keyword was renamed to `visible` in
	# Matplotlib 3.5 and later removed, so `b=True` breaks on newer versions.
	plt.grid(True)

	if options.logx or xfield.log:
		ax.set_xscale('log')
	else:
		plt.xlim(xfield.min, mx + 0.25)

	if options.logy or yfield.log:
		ax.set_yscale('log')
	else:
		plt.ylim(yfield.min, options.MaxY if options.MaxY else my*1.2)

	# Formatters are installed after set_*scale(): changing the scale resets the
	# axis formatter to the scale's default, which would discard them.
	ax.xaxis.set_major_formatter( EngFormatter(unit=xfield.unit) )
	ax.yaxis.set_major_formatter( EngFormatter(unit=yfield.unit) )

	plt.legend(loc='upper left')

	print("Results Ready")
	if options.out:
		plt.savefig(options.out, bbox_inches='tight')
	else:
		plt.show()


if __name__ == "__main__":
	# ================================================================================
	# parse command line arguments
	parser = argparse.ArgumentParser(description='Python Script to draw R.M.I.T. results')
	parser.add_argument('-f', '--file', nargs='?', type=argparse.FileType('r'), default=sys.stdin, help="Input file")
	parser.add_argument('-o', '--out', nargs='?', type=str, default=None, help="Output file")
	parser.add_argument('-y', nargs='?', type=str, default="", help="Which field to use as the Y axis")
	parser.add_argument('-x', nargs='?', type=str, default="", help="Which field to use as the X axis")
	parser.add_argument('--logx', action='store_true', help="if set, makes the x-axis logscale")
	parser.add_argument('--logy', action='store_true', help="if set, makes the y-axis logscale")
	parser.add_argument('--MaxY', nargs='?', type=int, help="maximum value of the y-axis")

	options =  parser.parse_args()

	# ================================================================================
	# load data
	try :
		data = json.load(options.file)
	except :
		print('ERROR: could not read input', file=sys.stderr)
		parser.print_help(sys.stderr)
		sys.exit(1)

	# ================================================================================
	# identify the keys

	series = set()
	fields = set()

	for entry in data:
		series.add(entry[0])
		for label in entry[2].keys():
			fields.add(label)

	# find the common prefix on series for removal
	prefix = os.path.commonprefix(list(series))

	if not options.out :
		print(series)
		print("fields: ", ' '.join(fields))

	wantx = "Number of processors"
	wanty = "ns per ops"

	if options.x:
		if options.x in field_names.keys():
			wantx = options.x
		else:
			print("Could not find X key '{}', defaulting to '{}'".format(options.x, wantx))

	if options.y:
		if options.y in field_names.keys():
			wanty = options.y
		else:
			print("Could not find Y key '{}', defaulting to '{}'".format(options.y, wanty))


	plot(data, wantx, wanty, options, prefix)
