#!/usr/bin/python3
"""
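Python script implementing R.M.I.T. testing: Randomized Multiple Interleaved Trials.

Usage:
    ./rmit.py [--trials N] [--list] [--file OUT] [--notaskset] [--extra ARG]
              COMMAND [CANDIDATE ...] [-key values ...]

Every combination of the -key option values is run against ./COMMAND (or
./COMMAND-CANDIDATE for each candidate) N times, in random order. A value is
a single integer (4), a range list (1-4,8), a backslash-prefixed literal list
of strings, or an arithmetic expression over other single-letter options (p*2).
Lines of the form "name: number" printed by each run are collected as JSON.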
"""


import argparse
import datetime
import itertools
import json
import os
import random
import re
import socket
import subprocess
import sys


def parse_range(x):
	"""Expand a comma-separated list of integers and inclusive ranges.

	Each comma-separated chunk is either a single integer ("7") or a
	"lo-hi" range whose both ends are included ("1-3" -> 1, 2, 3).
	Example: "1-3,7" -> [1, 2, 3, 7]. A reversed range ("4-2") expands
	to nothing. Malformed chunks raise ValueError.
	"""
	values = []
	for chunk in x.split(','):
		if '-' not in chunk:
			values.append(int(chunk))
			continue
		lo, hi = chunk.split('-')
		values += range(int(lo), int(hi) + 1)
	return values

class DependentOpt:
	"""An option whose value is an arithmetic expression over other options.

	Attributes:
	  key   -- the option name, without its leading '-'.
	  value -- the raw expression text, e.g. "p*2".
	  vars  -- every ASCII letter found in value, in order and including
	           repeats; each letter is treated as a reference to another
	           (single-letter) option key.
	"""

	def __init__(self, key, value):
		self.vars = re.findall("[a-zA-Z]", value)
		self.key, self.value = key, value

def parse_option(key, values):
	"""Parse the value text of one "-key values" command-line option.

	Returns a (key, parsed) pair where parsed is:
	  * [int]          -- when values is a single integer, e.g. "4";
	  * list of str    -- when values starts with a backslash: the rest is
	                      split on commas and kept as literal strings;
	  * list of int    -- when values only contains digits, '-' and ',',
	                      expanded by parse_range, e.g. "1-3,8";
	  * DependentOpt   -- anything else, treated as an expression over
	                      other option keys, e.g. "p*2".
	"""
	try:
		return key, [int(values)]
	except ValueError:
		# Not a plain integer; try the other formats below.
		pass

	if values.startswith('\\'):
		return key, values[1:].split(',')
	if re.search("^[0-9-,]+$", values):
		# parse_range builds a fresh list, so no extra copy is needed.
		return key, parse_range(values)
	return key, DependentOpt(key, values)

def eval_one(fmt, vals):
	"""Evaluate a dependent-option expression for one combination of inputs.

	fmt  -- the expression given on the command line, e.g. "p*2".
	vals -- iterable of (key, value) pairs; every occurrence of each key in
	        fmt is textually replaced by str(value), in iteration order.

	Returns the value of the substituted expression (an int, or a float
	when '/' is involved). Prints an error and exits with status 1 if the
	substituted text is not plain arithmetic or cannot be evaluated
	(syntax error, division by zero, overflow).
	"""
	orig = fmt
	for k, v in vals:
		fmt = fmt.replace(k, str(v))

	# Only digits, spaces and + - * / may remain, so eval cannot reach
	# names, attributes or calls.
	if not re.search("^[0-9-/*+ ]+$", fmt):
		print('ERROR: pattern option {} (interpreted as {}) could not be evaluated'.format(orig, fmt), file=sys.stderr)
		sys.exit(1)

	try:
		# Builtins are removed as defense in depth on top of the filter above.
		return eval(fmt, {"__builtins__": {}}, {})
	except (SyntaxError, ZeroDivisionError, OverflowError) as err:
		# e.g. "1+" or "1/0": still arithmetic characters, but not evaluable.
		print('ERROR: pattern option {} (interpreted as {}) could not be evaluated: {}'.format(orig, fmt, err), file=sys.stderr)
		sys.exit(1)

def eval_options(opts):
	"""Expand the parsed options into one argument list per combination.

	opts maps each option key (without its leading '-') to either a list of
	candidate values, or a DependentOpt whose value is an arithmetic
	expression over other option keys (e.g. "p*2").

	Every combination of the plain option values is enumerated with
	itertools.product. For each combination, every DependentOpt is evaluated
	with eval_one. Dependents therefore never add combinations; they only
	derive one value per combination.

	Returns a list of argument lists such as [['-p', '1', '-x', '2'], ...].
	Plain options referenced by dependents come first. The other plain
	options follow in their order in opts, and the dependents come last.
	With no options at all, a single empty argument list is returned.

	Prints an error and exits with status 1 when a dependent references an
	unknown key or another dependent (nesting is not supported).
	"""
	dependents = [d for d in opts.values() if isinstance(d, DependentOpt)]

	# Plain options, in the order their arguments will be emitted: inputs of
	# the dependents first, then the remaining ("straggler") options.
	input_keys = {}
	for d in dependents:
		for dvar in d.vars:
			if dvar not in opts:
				print('ERROR: extra pattern option {}:{} uses unknown key {}'.format(d.key,d.value,dvar), file=sys.stderr)
				sys.exit(1)

			# Inspect the referenced option's value (dvar itself is only a
			# letter): a dependent input has no value list to expand.
			if isinstance(opts[dvar], DependentOpt):
				print('ERROR: dependent options cannot be nested {}:{} uses key {}'.format(d.key,d.value,dvar), file=sys.stderr)
				sys.exit(1)

			input_keys.setdefault(dvar, opts[dvar])

	for k, v in opts.items():
		if not isinstance(v, DependentOpt):
			input_keys.setdefault(k, v)

	# One list of (key, value) pairs per plain option, so that
	# itertools.product yields each combination as a tuple of pairs.
	input_list = [[(k, o) for o in v] for k, v in input_keys.items()]

	formated = []
	for inputs in itertools.product(*input_list):
		pairs = list(inputs)
		pairs.extend((d.key, eval_one(d.value, inputs)) for d in dependents)

		args = []
		for k, v in pairs:
			args.append("-{}".format(k))
			args.append("{}".format(v))
		formated.append(args)

	return formated

def search_option(action, opt):
	"""Return the argument that follows the first occurrence of opt in action.

	action is an argument list such as ['./bench', '-p', '4']. Returns None
	when opt does not occur, or when its first occurrence is the last
	element (so it has no value).
	"""
	for flag, value in zip(action, action[1:]):
		if flag == opt:
			return value
	return None

def _duration_seconds(text):
	"""Return the -d value text as an int, truncating floats; 0 if not numeric."""
	try:
		return int(text)
	except ValueError:
		pass
	try:
		# e.g. "2.5", produced by a dependent expression that uses '/'.
		return int(float(text))
	except (ValueError, OverflowError):
		# Literal strings, nan or inf cannot contribute to the estimate.
		return 0

def actions_eta(actions):
	"""Estimate the total run time of actions as the sum of their -d values.

	For each action, the value after its first '-d' option is added up
	(presumably a duration in seconds -- confirm against the benchmarked
	programs). Actions without a -d value, or whose -d value is not numeric,
	contribute 0 instead of aborting the run.
	"""
	total = 0
	for a in actions:
		o = search_option(a, '-d')
		if o:
			total += _duration_seconds(o)
	return total

# Core-count range -> taskset CPU list for the current host.
# Set by init_taskset_maps(); stays None until then, and also if the host is unknown.
taskset_maps = None

def init_taskset_maps():
	"""Load the taskset CPU map for this machine into taskset_maps.

	The map is chosen by socket.gethostname() from a hard-coded table.
	Returns True when the host is known. Otherwise it prints a warning to
	stderr, leaves taskset_maps untouched and returns False.
	"""
	global taskset_maps
	per_host = {
		"jax": {
			range(  1,  25) : "48-71",
			range( 25,  49) : "48-71,144-167",
			range( 49,  97) : "48-95,144-191",
			range( 97, 145) : "24-95,120-191",
			range(145, 193) : "0-95,96-191",
		},
		"nasus": {
			range(  1,  65) : "64-127",
			range( 65, 129) : "64-127,192-255",
			range(129, 193) : "64-255",
			range(193, 257) : "0-255",
		},
		"ocean": {
			range(  1,  33) : "0-31",
		},
	}

	hostname = socket.gethostname()
	mapping = per_host.get(hostname)
	if mapping is None:
		print("Warning unknown host '{}', disable taskset usage".format(hostname), file=sys.stderr)
		return False

	taskset_maps = mapping
	return True


def settaskset_one(action):
	"""Prefix action with a `taskset -c CPUS` pinning chosen from its -p value.

	The value after the first '-p' is read as a core count and looked up in
	the global taskset_maps (core-count ranges -> CPU list string). On a
	match, returns a new list ['taskset', '-c', CPUS, *action].

	Otherwise action is returned unchanged. That happens when there is no
	-p value, when it is not an integer, when no map has been loaded, or
	when no range contains the count. Only the last case prints a warning
	to stderr.
	"""
	o = search_option(action, '-p')
	if not o or taskset_maps is None:
		return action
	try:
		cores = int(o)
	except ValueError:
		return action

	for cores_range, cpus in taskset_maps.items():
		if cores in cores_range:
			return ['taskset', '-c', cpus, *action]

	print("Warning no mapping for {} cores".format(cores), file=sys.stderr)
	return action

def settaskset(actions):
	return [settaskset_one(a) for a in actions]

if __name__ == "__main__":
	# ================================================================================
	# parse command line arguments
	parser = argparse.ArgumentParser(description='Python Script to implement R.M.I.T. testing : Randomized Multiple Interleaved Trials')
	parser.add_argument('--list', help='List all the commands that would be run', action='store_true')
	# A bare "--file" (formerly nargs='?') produced None and crashed on the
	# first write, so a file name is now required.
	parser.add_argument('--file', type=argparse.FileType('w'), default=sys.stdout, help='Write the JSON results to this file (default: stdout)')
	parser.add_argument('--trials', help='Number of trials to run per combinaison', type=int, default=3)
	parser.add_argument('--notaskset', help='If specified, the trial will not use taskset to match the -p option', action='store_true')
	parser.add_argument('--extra', help='Extra arguments to be added unconditionally', action='append', type=str)
	parser.add_argument('command', metavar='command', type=str, nargs=1, help='the command prefix to run')
	parser.add_argument('candidates', metavar='candidates', type=str, nargs='*', help='the candidate suffix to run')

	# argparse reports its own errors and handles --help itself (exit codes 2
	# and 0). Whatever it does not recognise must be "-key value" pairs, each
	# of which becomes one option to sweep.
	options, unknown = parser.parse_known_args()
	if len(unknown) % 2 != 0:
		parser.error('extra options must be given as "-key value" pairs, got: {}'.format(' '.join(unknown)))
	options.option = []
	for key, val in zip(unknown[0::2], unknown[1::2]):
		if not key.startswith('-'):
			parser.error('expected an option of the form -key, got {!r}'.format(key))
		options.option.append((key[1:], val))

	# ================================================================================
	# Identify the commands to run
	command = './' + options.command[0]
	if options.candidates:
		commands = [command + "-" + c for c in options.candidates]
	else:
		commands = [command]
	for c in commands:
		if not os.path.isfile(c):
			print('ERROR: invalid command {}, file does not exist'.format(c), file=sys.stderr)
			sys.exit(1)

		if not os.access(c, os.X_OK):
			print('ERROR: invalid command {}, file not executable'.format(c), file=sys.stderr)
			sys.exit(1)


	# ================================================================================
	# Identify the options to run
	opts = dict(parse_option(k, v) for k, v in options.option)

	# Evaluate the options (options can depend on the value of other options)
	opts = eval_options(opts)

	# ================================================================================
	# Figure out all the combinations to run: one fresh list per
	# (trial, command, option combination).
	actions = [[cmd, *args] for _, cmd, args in itertools.product(range(options.trials), commands, opts)]

	# ================================================================================
	# Fixup the different commands

	# add extras
	if options.extra:
		for act in actions:
			act.extend(options.extra)

	# Add tasksets
	if not options.notaskset and init_taskset_maps():
		actions = settaskset(actions)

	# ================================================================================
	# Now that we know what to run, print it.
	# find expected time
	eta = actions_eta(actions)
	print("Running {} trials{}".format(len(actions), "" if eta == 0 else " (expecting to take {})".format(str(datetime.timedelta(seconds=int(eta)))) ))

	# dry run if options ask for it
	if options.list:
		for a in actions:
			print(" ".join(a))
		sys.exit(0)


	# ================================================================================
	# Prepare to run

	random.shuffle(actions)

	# Output lines of the form "name : number" become result fields.
	field_re = re.compile("^(.*):(.*)$")

	# ================================================================================
	# Run
	options.file.write("[")
	for i, a in enumerate(actions):
		# settaskset_one only prefixes actions it could map (those with a
		# usable -p), so strip 'taskset -c CPUS' per action, not globally.
		cmd = a[3:] if a[0] == 'taskset' else a
		sa = " ".join(cmd)
		if i > 0:
			options.file.write(",")
		if options.file is not sys.stdout:
			print("{}/{} : {}          \r".format(i, len(actions), sa), end = '')
		fields = {}
		with subprocess.Popen(a, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
			out, err = proc.communicate()
			if proc.returncode != 0:
				print("ERROR: command '{}' encountered error, returned code {}".format(sa, proc.returncode), file=sys.stderr)
				# stderr, so a failure cannot corrupt JSON written to stdout.
				print(err.decode("utf-8", errors="replace"), file=sys.stderr)
				sys.exit(1)
			for s in out.decode("utf-8", errors="replace").splitlines():
				match = field_re.search(s)
				if match:
					try:
						fields[match.group(1).strip()] = float(match.group(2).strip().replace(',',''))
					except ValueError:
						# Value is not a number: not a result field.
						pass

		# The result name is the command path without its leading './'.
		options.file.write(json.dumps([cmd[0][2:], sa, fields]))
		options.file.flush()

	options.file.write("]\n")

	if options.file is not sys.stdout:
		print("Done                                                                                ")
