#
# Copyright (C) Lynn Tran, Jiachen Zhang 2018
#
# utils-gdb.py --
#
# Author           : Lynn Tran
# Created On       : Mon Oct 1 22:06:09 2018
# Last Modified By : Peter A. Buhr
# Last Modified On : Sat Jan 19 14:16:10 2019
# Update Count     : 11
#

"""
To load this extension, the name of this Python file must match the name of one of the loaded libraries.
Additionally, the file must be located in a directory that is on gdb's safe path.
"""
import collections
import gdb
import re

# Let SIGALRM and SIGUSR1 reach the debugged program without gdb stopping on
# them or printing a notice ("nostop noprint pass").
for _signal_cmd in ('handle SIGALRM nostop noprint pass',
                    'handle SIGUSR1 nostop noprint pass'):
	gdb.execute(_signal_cmd)
del _signal_cmd

# gdb types for the Cforall runtime structures, filled in by get_cfa_types()
CfaTypes = collections.namedtuple('CfaTypes', ['cluster_ptr', 'processor_ptr', 'thread_ptr', 'int_ptr', 'thread_state'])

class ThreadInfo:
	"""A thread found while walking a cluster's thread list.

	cluster: the cluster the thread was found on
	value:   the thread value itself, as passed by the caller
	tid:     display id; stays 0 until the caller assigns one
	"""
	# Class-level defaults; __init__ shadows cluster and value per instance.
	tid = 0
	cluster = None
	value = None

	def __init__(self, cluster, value):
		self.value = value
		self.cluster = cluster

	def is_system(self):
		"""Report whether this is a runtime/system thread (currently never)."""
		return False

# Register snapshot (sp, fp, pc) taken before switching to another thread's stack
StackInfo = collections.namedtuple('StackInfo', ('sp', 'fp', 'pc'))

# Saved register snapshots, most recent last: cfathread pushes onto it,
# prevtask pops the newest entry and reset restores the oldest one.
STACK = []

# Names of system tasks.
# NOTE(review): these look like leftover uC++ task names; 'uProcessorTask'
# appears twice (harmless for membership tests).
SysTask_Name = [
	"uLocalDebuggerReader", "uLocalDebugger", "uProcessorTask", "uBootTask",
	"uSystemTask", "uProcessorTask", "uPthread", "uProfiler",
]

not_supported_error_msg = "Not a supported command for this language"

def is_cforall():
	"""Report whether the debugged program is a Cforall program.

	Always True here; the commands call it as a guard before doing any work.
	"""
	cforall_program = True
	return cforall_program

def get_cfa_types():
	"""Look up the gdb types used to cast and interpret Cforall runtime values.

	Return: CfaTypes holding pointer types to struct cluster, struct processor,
	        struct $thread and int, plus the enum coroutine_state type.
	"""
	lookup = gdb.lookup_type
	cluster_ptr = lookup('struct cluster').pointer()
	processor_ptr = lookup('struct processor').pointer()
	thread_ptr = lookup('struct $thread').pointer()
	int_ptr = lookup('int').pointer()
	thread_state = lookup('enum coroutine_state')
	return CfaTypes(cluster_ptr, processor_ptr, thread_ptr, int_ptr, thread_state)

def get_addr(addr):
	"""
	Drop gdb's trailing "<symbol+offset>" annotation from an address.

	NOTE: string-based and fragile; a gdb.Value-based approach would be better.
	@addr: object whose str() is an address, optionally followed by "<...>",
	       e.g. "0x401136 <main+4>"
	Return: str(addr) unchanged if it contains no '<'; otherwise the text before
	        the first '<' with surrounding whitespace stripped
	"""
	text = str(addr)
	head, sep, _ = text.partition('<')
	if not sep:
		return text
	return head.strip()

def print_usage(obj):
	"""Print the docstring of obj (a command class or instance) as its usage text."""
	usage_text = obj.__doc__
	print(usage_text)

def parse(args):
	"""
	Split a gdb command argument string into an argv-style list.

	Arguments are separated by runs of whitespace (spaces or tabs). Leading and
	trailing whitespace is ignored, so repeated spaces never produce empty arguments.
	@args: str of arguments
	Return:
		[] if args is empty or contains only whitespace
		list of non-empty str otherwise
	"""
	# str.split() with no separator collapses whitespace runs. split(' ') would
	# produce '' entries, e.g. for " 1" or "1  main", and break argv[0] checks.
	return args.split()

def get_cluster_root():
	"""
	Fetch the runtime's main cluster.

	Evaluates the mangled symbol _X11mainClusterPS7cluster_1, which is presumably
	a `cluster *` judging by the mangling.
	Return: gdb.Value of that symbol. A notice is printed when its address is 0x0,
	        but the value is returned anyway; callers check `.address` themselves.
	"""
	root = gdb.parse_and_eval('_X11mainClusterPS7cluster_1')
	missing = root.address == 0x0
	if missing:
		print('No clusters, program terminated')
	return root

def find_curr_thread():
	"""Identify the Cforall thread that is currently running.

	Not implemented yet: always returns None, which callers report as
	"current thread unknown".
	"""
	# TODO: an earlier attempt parsed 'this=' out of the first `bt` frame;
	# that only works for uC++-style member functions.
	return None

def all_clusters():
	"""
	Collect every cluster reachable from the main cluster.

	Walks the circular debug list that links clusters through each cluster's
	node.next field, starting at (and including) the main cluster.
	Return: list of gdb.Value clusters, main cluster first; [] when this is not a
	        Cforall program or there is no main cluster to start from.
	"""
	# Always return a list (never None) so callers can iterate the result directly.
	if not is_cforall():
		return []

	cluster_root = get_cluster_root()
	# NOTE(review): the null-pointer check assumes mainCluster is a `cluster *`
	# (per its mangled name). A null root cannot be walked.
	if cluster_root.address == 0x0 or cluster_root == 0x0:
		return []

	ret = [cluster_root]
	curr = cluster_root
	while True:
		curr = curr['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']
		# Stop once the list wraps around to the start, or on a null link
		# (a broken or partially built list) instead of dereferencing it.
		if curr == cluster_root or curr == 0x0:
			break
		ret.append(curr)

	return ret


def lookup_cluster(name = None):
	"""
	Look up a cluster by its name.
	@name: str; None or '' selects the main cluster
	Return: gdb.Value of the cluster (the same kind of value as the main cluster
	        root), or None if this is not Cforall, there is no main cluster, or no
	        cluster has that name (a message is printed in the last case)
	"""
	if not is_cforall():
		return None

	root = get_cluster_root()
	# NOTE(review): the null-pointer check assumes mainCluster is a `cluster *`.
	if root.address == 0x0 or root == 0x0:
		return None

	if not name:
		return root

	# Walk the circular cluster list until the name matches or the list wraps.
	curr = root
	while True:
		if curr['_X4namePKc_1'].string() == name:
			# Return the cluster value itself, as for the main cluster above.
			# curr.address would instead be the address of the link field
			# holding it, which callers then cast to a cluster pointer.
			return curr
		curr = curr['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']
		if curr == root or curr == 0x0:
			break

	print("Cannot find a cluster with the name: {}.".format(name))
	return None

def lookup_threads_by_cluster(cluster):
	"""
	Collect the threads on a cluster's thread list.

	Starts at the head of the cluster's `threads` list (cast to a $thread
	pointer) and follows each thread's node.next link until the list wraps
	around or a null link is reached.
	@cluster: gdb.Value of a cluster
	Return: list of ThreadInfo. Non-system threads get ids 0, 1, 2, ... and
	        system threads get -1, -2, ... The list is [] (after a message) when
	        the head is null.
	"""
	found = []

	cfa_t = get_cfa_types()
	head = cluster['_X7threadsS8__dllist_S7$thread__1']['_X4headPY15__TYPE_generic__1']
	root = head.cast(cfa_t.thread_ptr)

	if root == 0x0 or root.address == 0x0:
		print('There are no tasks for cluster: {}'.format(cluster))
		return found

	next_user_id, next_system_id = 0, -1
	curr = root
	while True:
		info = ThreadInfo(cluster, curr)
		if info.is_system():
			info.tid = next_system_id
			next_system_id -= 1
		else:
			info.tid = next_user_id
			next_user_id += 1
		found.append(info)

		curr = curr['node']['next']
		if curr == root or curr == 0x0:
			break

	return found

def system_thread(thread):
	"""Report whether `thread` is a runtime/system thread. Always False for now."""
	is_system = False
	return is_system

def adjust_stack(pc, fp, sp):
	"""Write pc, frame pointer ($rbp) and stack pointer into the registers via gdb `set`.

	Used to restore a saved StackInfo snapshot. The registers are written in the
	order $pc, $rbp, $sp, and each value is formatted just before its own `set`.
	"""
	for template, value in (('set $pc = {}', pc),
	                        ('set $rbp = {}', fp),
	                        ('set $sp = {}', sp)):
		gdb.execute(template.format(value))

############################ COMMAND IMPLEMENTATION #########################

class Clusters(gdb.Command):
	"""Cforall: Display currently known clusters
Usage:
	info clusters                 : print out all the clusters
"""

	def __init__(self):
		super(Clusters, self).__init__('info clusters', gdb.COMMAND_USER)

	def print_cluster(self, cluster_name, cluster_address):
		"""Print one table row: cluster name and address, each right-aligned to 20 columns."""
		print('{:>20}  {:>20}'.format(cluster_name, cluster_address))

	#entry point from gdb
	def invoke(self, arg, from_tty):
		"""
		Print a Name/Address table of all clusters reachable from the main cluster.
		@arg: str, must be empty
		@from_tty: bool
		"""
		if not is_cforall():
			return

		if arg:
			print("info clusters does not take arguments")
			print_usage(self)
			return

		self.print_cluster('Name', 'Address')

		# all_clusters() can return None when there is no main cluster. Treat
		# that as "no clusters" rather than failing with a TypeError.
		for c in all_clusters() or []:
			self.print_cluster(c['_X4namePKc_1'].string(), str(c))

		print("")

############
class Processors(gdb.Command):
	"""Cforall: Display currently known processors
Usage:
	info processors                 : print out all the processors in the Main Cluster
	info processors all             : print out all processors in all clusters
	info processors <cluster_name>  : print out all processors in a given cluster
"""

	def __init__(self):
		super(Processors, self).__init__('info processors', gdb.COMMAND_USER)

	def print_processor(self, name, status, pending, address):
		"""Print one row of the processor table."""
		print('{:>20}  {:>11}  {:>13}  {:>20}'.format(name, status, pending, address))

	def iterate_procs(self, root, active):
		"""
		Print one row per processor on the circular list that starts at root.
		@root: gdb.Value processor pointer (list head). Nothing is printed if it is 0x0.
		@active: True for the cluster's active list, False for its idle list. This
		         picks the status shown for processors that have not been asked to terminate.
		"""
		if root == 0x0:
			return

		processor = root
		while True:
			should_stop = processor['_X12do_terminateVb_1']
			stop_count  = processor['_X10terminatedS9semaphore_1']['_X5counti_1']
			if not should_stop:
				status = 'Active' if active else 'Idle'
			else:
				status_str  = 'Last Thread' if stop_count >= 0 else 'Terminating'
				status      = '{}({},{})'.format(status_str, should_stop, stop_count)

			self.print_processor(processor['_X4namePKc_1'].string(),
					status, str(processor['_X18pending_preemptionb_1']), str(processor)
				)

			processor = processor['_X4nodeS28__processor____dbg_node_proc_1']['_X4nextPS9processor_1']

			# stop when the list wraps around or on a null link
			if processor == root or processor == 0x0:
				break

	#entry point from gdb
	def invoke(self, arg, from_tty):
		"""
		@arg: str; '' for the main cluster, 'all', or a cluster name
		@from_tty: bool
		"""
		if not is_cforall():
			return

		if not arg:
			clusters = [lookup_cluster(None)]
		elif arg == "all":
			clusters = all_clusters()
		else:
			clusters = [lookup_cluster(arg)]

		# lookup_cluster() returns None on failure and all_clusters() may return
		# None. Drop those entries so a failed lookup is reported here, instead of
		# producing [None] (truthy) and crashing when it is subscripted below.
		clusters = [c for c in (clusters or []) if c is not None]
		if not clusters:
			print("No Cluster matching arguments found")
			return

		cfa_t = get_cfa_types()
		for cluster in clusters:
			cluster_ptr = cluster.cast(cfa_t.cluster_ptr)
			print('Cluster: "{}"({})'.format(cluster['_X4namePKc_1'].string(), cluster_ptr))

			active_root = cluster_ptr \
					['_X5procsS8__dllist_S9processor__1'] \
					['_X4headPY15__TYPE_generic__1'] \
					.cast(cfa_t.processor_ptr)

			idle_root = cluster_ptr \
					['_X5idlesS8__dllist_S9processor__1'] \
					['_X4headPY15__TYPE_generic__1'] \
					.cast(cfa_t.processor_ptr)

			if idle_root != 0x0 or active_root != 0x0:
				self.print_processor('Name', 'Status', 'Pending Yield', 'Address')
				self.iterate_procs(active_root, True)
				self.iterate_procs(idle_root, False)
			else:
				print("No processors on cluster")

		print()

############
class Threads(gdb.Command):
	"""Cforall: Display currently known threads
Usage:
	info cfathreads                      : print Main Cluster threads, application threads only
	info cfathreads all                  : print all clusters, all threads
	info cfathreads <clusterName>        : print cluster threads, application threads only
	"""
	def __init__(self):
		# Registered under gdb's 'info' prefix, i.e. invoked as 'info cfathreads'.
		super(Threads, self).__init__('info cfathreads', gdb.COMMAND_USER)

	def print_formatted(self, marked, tid, name, state, address):
		"""Print one table row; rows with `marked` set are flagged with '*'."""
		print('{:>1}  {:>4}  {:>20}  {:>10}  {:>20}'.format('*' if marked else ' ', tid, name, state, address))

	def print_thread(self, thread, tid, marked):
		"""Print one thread (its coroutine name, state and address) with display id `tid`."""
		cfa_t = get_cfa_types()
		self.print_formatted(marked, tid, thread['self_cor']['name'].string(), str(thread['state'].cast(cfa_t.thread_state)), str(thread))

	def print_threads_by_cluster(self, cluster, print_system = False):
		"""
		Print a table of the threads on `cluster`.
		@cluster: gdb.Value of a cluster
		@print_system: also print threads whose is_system() is True
		"""
		threads = lookup_threads_by_cluster(cluster)
		if not threads:
			return

		running_thread = find_curr_thread()
		if running_thread is None:
			print('Could not identify current thread')

		self.print_formatted(False, '', 'Name', 'State', 'Address')

		for t in threads:
			if not t.is_system() or print_system:
				self.print_thread(t.value, t.tid, t.value == running_thread if running_thread else False)

		print()

	def print_all_threads(self):
		"""Print the thread table of every cluster."""
		# all_clusters() can return None when there is no main cluster; treat that
		# as "no clusters" rather than failing with a TypeError.
		for c in all_clusters() or []:
			self.print_threads_by_cluster(c, False)

	def invoke(self, arg, from_tty):
		"""
		@arg: str; '' for the main cluster, 'all', or a cluster name
		@from_tty: bool
		"""
		if not is_cforall():
			return

		if not arg:
			cluster = lookup_cluster()
			if not cluster:
				print("Could not find Main Cluster")
				return

			# only tasks and main
			self.print_threads_by_cluster(cluster, False)

		elif arg == 'all':
			# all clusters
			# NOTE(review): print_system is False here but True for a named cluster,
			# the opposite of what the usage text suggests. This is moot while
			# ThreadInfo.is_system() always returns False.
			self.print_all_threads()

		else:
			cluster = lookup_cluster(arg)
			if not cluster:
				print("Could not find cluster '{}'".format(arg))
				return

			# all tasks, specified cluster
			self.print_threads_by_cluster(cluster, True)


############
class Thread(gdb.Command):
	"""Cforall: Switch to specified user threads
Usage:
	cfathread <id>                       : switch stack to thread id on main cluster
	cfathread 0x<address>	             : switch stack to thread on any cluster
	cfathread <id> <clusterName>         : switch stack to thread on specified cluster
	"""
	def __init__(self):
		# Registered as the top-level user command 'cfathread'.
		super(Thread, self).__init__('cfathread', gdb.COMMAND_USER)

	############################ AUXILIARY FUNCTIONS #########################

	def switchto(self, thread):
		"""Switch gdb's view to another thread's stack by rewriting sp, fp and pc.

		The current $sp/$fp/$pc are pushed onto STACK so that prevtask/reset can
		switch back.
		@thread: gdb.Value of a Cforall thread (cast to a struct $thread pointer by callers)
		Return: None. When the switch is not possible, a message is printed and
		        the registers are left untouched.
		"""
		cfa_t = get_cfa_types()

		# The resume pc lies inside __cfactx_switch, so evaluating that symbol is
		# also the real availability check. (gdb.lookup_symbol returns a 2-tuple,
		# which is always truthy, so `not lookup_symbol(...)` could never fire.)
		# NOTE(review): +28 is presumably the offset just past the context save in
		# __cfactx_switch -- confirm against the runtime's assembly.
		try:
			xpc = get_addr(gdb.parse_and_eval('__cfactx_switch').address + 28)
		except gdb.error:
			print('__cfactx_switch symbol is unavailable')
			return

		state = thread['state'].cast(cfa_t.thread_state)
		try:
			if state == gdb.parse_and_eval('Halted'):
				print('Cannot switch to a terminated thread')
				return

			if state == gdb.parse_and_eval('Start'):
				print('Cannot switch to a thread not yet run')
				return
		except gdb.error as err:
			print('Cannot determine the state of thread {}: {}'.format(thread, err))
			return

		context = thread['context']

		# NOTE(review): +48 presumably skips what __cfactx_switch pushed onto the
		# saved stack -- confirm against the runtime's assembly.
		xsp = context['SP'] + 48
		xfp = context['FP']

		# must be at frame 0 to set pc register
		gdb.execute('select-frame 0')

		# Save the current registers so that prevtask/reset can switch back.
		sp = gdb.parse_and_eval('$sp')
		fp = gdb.parse_and_eval('$fp')
		pc = gdb.parse_and_eval('$pc')
		STACK.append(StackInfo(sp = sp, fp = fp, pc = pc))

		# update registers for the new thread
		print('Switching to thread {}'.format(thread))
		gdb.execute('set $rsp={}'.format(xsp))
		gdb.execute('set $rbp={}'.format(xfp))
		gdb.execute('set $pc={}'.format(xpc))

	def find_matching_gdb_thread_id(self, task):
		"""
		Find the gdb thread number whose `info thread` line mentions `this=<task>`.

		The previous version took no `self`/`task` parameters, used undefined
		names and called itself recursively, so it could never run.
		@task: value whose str() is matched against the `this=` text
		Return: str gdb thread number, or None if no line matches
		"""
		thread_id_pattern = re.compile(r'^\*?\s+(\d+)\s+Thread')
		needle = 'this={}'.format(task)
		for thread_str in gdb.execute('info thread', to_string=True).splitlines():
			if needle in thread_str:
				match = thread_id_pattern.match(thread_str)
				if match:
					return match.group(1)
		return None

	def switchto_id(self, tid, cluster):
		"""
		Switch to the thread with display id `tid` on `cluster`.
		@tid: int, the id shown by `info cfathreads`
		@cluster: gdb.Value of a cluster
		"""
		for t in lookup_threads_by_cluster(cluster):
			if t.tid == tid:
				self.switchto(t.value)
				return

		print("Could not find thread by id '{}'".format(tid))

	def invoke(self, arg, from_tty):
		"""
		@arg: str; '<id> [clusterName]' or '0x<address>'
		@from_tty: bool
		"""
		if not is_cforall():
			return

		argv = parse(arg)
		# An empty or blank argument previously raised IndexError on argv[0].
		if not argv or not argv[0]:
			print_usage(self)
			return

		if argv[0].isdigit():
			# by id, on the main cluster or on the named cluster
			cname = " ".join(argv[1:]) if len(argv) > 1 else None
			cluster = lookup_cluster(cname)
			if not cluster:
				print("Could not find cluster '{}'".format(cname if cname else "Main Cluster"))
				return

			try:
				tid = int(argv[0])
			except ValueError:
				# isdigit() also accepts characters such as '²' that int() rejects
				print("'{}' not a valid thread id".format(argv[0]))
				print_usage(self)
				return

			self.switchto_id(tid, cluster)

		elif argv[0].lower().startswith('0x'):
			# by address, on any cluster. switchto needs a thread value, not the
			# raw argument string.
			try:
				address = int(argv[0], 16)
			except ValueError:
				print("'{}' not a valid thread address".format(argv[0]))
				print_usage(self)
				return
			self.switchto(gdb.Value(address).cast(get_cfa_types().thread_ptr))

		else:
			print("'{}' is neither a thread id nor an address".format(argv[0]))
			print_usage(self)

############
class PrevThread(gdb.Command):
	"""Switch back to previous task on the stack"""
	usage_msg = 'prevtask'

	def __init__(self):
		super(PrevThread, self).__init__('prevtask', gdb.COMMAND_USER)

	def invoke(self, arg, from_tty):
		"""
		Restore the registers saved by the most recent switch (top of STACK).
		@arg: str (unused)
		@from_tty: bool
		"""
		if not STACK:
			print('empty stack')
			return

		# must be at frame 0 to set pc register
		gdb.execute('select-frame 0')

		# take the newest snapshot and write it back into the registers
		saved = STACK.pop()
		adjust_stack(get_addr(saved.pc), saved.fp, saved.sp)

		# afterwards select frame 1, the caller of the innermost frame
		gdb.execute('frame 1')

class ResetOriginFrame(gdb.Command):
	"""Reset to the origin frame prior to continue execution again"""
	usage_msg = 'resetOriginFrame'
	def __init__(self):
		super(ResetOriginFrame, self).__init__('reset', gdb.COMMAND_USER)

	def invoke(self, arg, from_tty):
		"""
		Restore the oldest saved register snapshot and discard all others.
		Does nothing (silently) when no switch has been made.
		@arg: str (unused)
		@from_tty: bool
		"""
		if not STACK:
			return

		origin = STACK[0]
		del STACK[:]

		# Must be at frame 0 to set the pc register, as in prevtask. Otherwise,
		# after `prevtask` has left frame 1 selected, the `set` commands would
		# write that frame's saved registers instead of the live ones.
		gdb.execute('select-frame 0')

		adjust_stack(get_addr(origin.pc), origin.fp, origin.sp)

		# then select frame 1, the caller of the innermost frame
		gdb.execute('frame 1')

# Instantiate each command class once; the gdb.Command constructor registers it with gdb.
for _command_class in (Clusters, Processors, ResetOriginFrame, PrevThread, Threads, Thread):
	_command_class()
del _command_class

# Local Variables: #
# mode: Python #
# End: #
