#
# Copyright (C) Lynn Tran, Jiachen Zhang 2018
#
# utils-gdb.py --
#
# Author           : Lynn Tran
# Created On       : Mon Oct 1 22:06:09 2018
# Last Modified By : Peter A. Buhr
# Last Modified On : Sat Jan 19 14:16:10 2019
# Update Count     : 11
#

"""
To run this extension, the Python file must be named after one of the loaded libraries
(gdb auto-loads scripts named <objfile>-gdb.py). Additionally, the file must be in a
folder that is in gdb's safe path.
"""
import collections
import gdb
import re

# Tell gdb to neither stop on nor print SIGALRM and SIGUSR1 and to pass them on to
# the program. NOTE(review): presumably the Cforall runtime relies on these signals
# (e.g. SIGALRM for the alarm-based preemption reported by print_thread) — confirm.
gdb.execute('handle SIGALRM nostop noprint pass')
gdb.execute('handle SIGUSR1 nostop noprint pass')

# Bundle of gdb.Type objects for the Cforall runtime structures and enums this
# script inspects; instances are built by get_cfa_types().
CfaTypes = collections.namedtuple('CfaTypes', 'cluster_ptr processor_ptr thread_ptr int_ptr thread_state yield_state')

class ThreadInfo:
	"""A Cforall thread value together with the cluster it was found on.

	`tid` is a display id that the caller assigns after construction
	(lookup_threads_by_cluster numbers the threads); it defaults to 0.
	"""
	# Class-level defaults; instances shadow them with their own values.
	tid = 0
	cluster = None
	value = None

	def __init__(self, cluster, value):
		self.cluster, self.value = cluster, value

	def is_system(self):
		"""Report whether this is a runtime (system) thread; not detected yet, so always False."""
		runtime_thread = False
		return runtime_thread

# Saved register values (stack pointer, frame pointer, program counter) captured
# from gdb's $sp, $fp and $pc before the debugger is moved onto another thread's stack.
StackInfo = collections.namedtuple('StackInfo', 'sp fp pc')

# Stack of StackInfo entries: Thread.switchto pushes the current registers before
# each switch, 'prevtask' pops the most recent entry and 'reset' restores the
# oldest one and empties the stack.
STACK = []

# Message for commands that do not apply to the current language.
# NOTE(review): not referenced by the visible code — possibly unused.
not_supported_error_msg = "Not a supported command for this language"

def is_cforall():
	"""Report whether the program being debugged is a Cforall program.

	Currently a stub: no check is performed and the answer is always True.
	"""
	return True

def get_cfa_types():
	"""Look up the gdb types of the Cforall runtime structures this script uses.

	Return: CfaTypes of gdb.Type objects — pointer types for struct cluster,
	struct processor, struct thread$ and int, plus the __Coroutine_State and
	__Preemption_Reason enums. gdb raises an exception if a type is missing
	from the inferior's debug information.
	"""
	lookup = gdb.lookup_type
	cluster_ptr = lookup('struct cluster').pointer()
	processor_ptr = lookup('struct processor').pointer()
	thread_ptr = lookup('struct thread$').pointer()
	int_ptr = lookup('int').pointer()
	thread_state = lookup('enum __Coroutine_State')
	yield_state = lookup('enum __Preemption_Reason')
	return CfaTypes(
		cluster_ptr=cluster_ptr,
		processor_ptr=processor_ptr,
		thread_ptr=thread_ptr,
		int_ptr=int_ptr,
		thread_state=thread_state,
		yield_state=yield_state,
	)

def get_addr(addr):
	"""
	Strip gdb's symbolic suffix from a printed address.

	NOTE: sketchy solution to retrieve address; working on the gdb.Value
	directly would be cleaner.
	@addr: object whose str() is an address, optionally followed by a
	       "<symbol>" annotation, e.g. "0x401136 <main+4>"
	Return: str of just the address (whitespace-stripped); when no "<" is
	        present the str() is returned unchanged
	"""
	text = str(addr)
	head, marker, _ = text.partition('<')
	if not marker:
		return text
	return head.strip()

def print_usage(obj):
	"""Print *obj*'s docstring (the gdb command classes keep their usage text there)."""
	usage_text = obj.__doc__
	print(usage_text)

def parse(args):
	"""
	Split a gdb argument string into an argv-like list.

	The separator is a single space, so consecutive spaces produce empty
	entries and " ".join() of any slice restores the original text exactly.
	@args: str of arguments
	Return:
		[] if args is an empty string
		list of the space-separated pieces otherwise
	"""
	if args == '':
		return []
	return args.split(' ')

class ClusterIter:
	"""Iterate over the clusters linked from *root*, starting with *root* itself.

	Clusters are linked through their debug node's `next` pointer and are
	expected to form a cycle; iteration ends when the walk returns to *root*.
	A null root or a null `next` link also ends the walk, so a broken or
	non-circular list cannot make the iterator restart at *root* forever.
	Once exhausted, the iterator stays exhausted.
	"""
	def __init__(self, root):
		self.curr = None	# None until the first cluster has been produced
		self.root = root	# also read directly by lookup_cluster()
		self.done = False

	def __iter__(self):
		return self

	def __next__(self):
		if self.done:
			raise StopIteration

		# First call: the root is the first cluster. Test for None explicitly —
		# a null gdb.Value is falsy, so `not self.curr` would mistake a null link
		# for "not started yet" and send the walk back to the root endlessly.
		if self.curr is None:
			if self.root == 0x0:
				self.done = True
				raise StopIteration
			self.curr = self.root
			return self.curr

		# Otherwise advance along the cycle.
		self.curr = self.curr['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']

		# Back at the root (full cycle) or at a null link: we are done.
		if self.curr == self.root or self.curr == 0x0:
			self.done = True
			raise StopIteration

		return self.curr

def all_clusters():
	"""
	Return an iterator over every cluster, starting at the main cluster.

	The walk starts from the global main-cluster pointer
	(_X11mainClusterPS7cluster_1) and follows the cluster cycle (ClusterIter).
	Return: ClusterIter, or [] when not debugging Cforall or when the main
	cluster pointer is null (reported as "program terminated").
	"""
	if not is_cforall():
		return []

	cluster_root = gdb.parse_and_eval('_X11mainClusterPS7cluster_1')
	# The global holds a pointer, so test the pointer value itself: the global's
	# own .address is never null, so checking only that could not detect a
	# missing main cluster.
	if cluster_root == 0x0 or cluster_root.address == 0x0:
		print('No clusters, program terminated')
		return []

	return ClusterIter(cluster_root)

class ProcIter:
	"""Iterate over the processors of one intrusive processor list.

	*root* is the list's first `next` link. A link with the most significant
	pointer bit set marks the end of the list (see check()); a null link is
	treated as the end too. Once exhausted, the iterator stays exhausted.
	"""
	def __init__(self, root):
		self.curr = None	# None until the first link has been followed
		self.root = root
		self.done = False

	def __iter__(self):
		return self

	def check(self):
		"""Raise StopIteration if self.curr is the end-of-list marker or null."""
		addr = int(self.curr)
		# most significant bit of a pointer-sized word
		mask = 1 << ((8 * int(gdb.parse_and_eval('sizeof(void*)'))) - 1)
		if addr == 0 or 0 != (mask & addr):
			self.done = True
			raise StopIteration

	def __next__(self):
		if self.done:
			raise StopIteration

		cfa_t = get_cfa_types()

		# First call: start from the root link. Test for None explicitly — a null
		# gdb.Value is falsy, so `not self.curr` would restart the walk.
		if self.curr is None:
			my_next = self.root
		else:
			# follow the current processor's intrusive `next` link
			my_next = self.curr['__anonymous_object2225']['_X4nextPY13__tE_generic__1']

		self.curr = my_next.cast(cfa_t.processor_ptr)

		# stops on an empty list or once the end of the list is reached
		self.check()

		return self.curr

def proc_list(cluster):
	"""
	Return: for a given cluster, its active and idle processors as two
	ProcIter iterators (active first, idle second).
	"""
	cfa_t = get_cfa_types()
	procs = cluster['_X5procsS19__cluster_proc_list_1']

	def first_link(list_field):
		# first `next` link of one of the cluster's processor dlists
		return procs[list_field]['__anonymous_object2167']['_X4nextPY13__tE_generic__1']

	idle_head = first_link('_X5idlesS5dlist_S9processorS5dlink_S9processor___1')
	active_head = first_link('_X7activesS5dlist_S9processorS5dlink_S9processor___1')
	return (
		ProcIter(active_head.cast(cfa_t.processor_ptr)),
		ProcIter(idle_head.cast(cfa_t.processor_ptr)),
	)

def all_processors():
	"""
	Collect the processors of every cluster (per cluster: active ones, then
	idle ones), print the resulting list and return it.
	"""
	collected = []
	for cluster in all_clusters():
		active, idle = proc_list(cluster)
		collected.extend(active)
		collected.extend(idle)

	print(collected)
	return collected

def tls_for_pthread(pthrd):
	"""
	Return the address of the Cforall kernel TLS block as seen by one pthread.

	@pthrd: gdb.Value handle accepted by Inferior.thread_from_thread_handle
	        (presumably a pthread_t)
	Return: gdb.Value of &_X9kernelTLSS16KernelThreadData_1 evaluated on that thread
	Raises gdb.GdbError if gdb knows no thread for the handle.

	gdb is temporarily switched to the target thread so the thread-local
	variable resolves for it. The previously selected thread is always
	re-selected, even when the evaluation fails.
	"""
	prev = gdb.selected_thread()
	inf = gdb.selected_inferior()

	thrd = inf.thread_from_thread_handle( pthrd )
	if thrd is None:
		raise gdb.GdbError('No gdb thread matches pthread handle {}'.format(pthrd))

	thrd.switch()
	try:
		tls = gdb.parse_and_eval('&_X9kernelTLSS16KernelThreadData_1')
	finally:
		# restore the user's thread selection even if the lookup failed
		prev.switch()
	return tls

def tls_for_proc(proc):
	"""Return the processor's `local_data` field (its KernelThreadData pointer)."""
	field = '_X10local_dataPS16KernelThreadData_1'
	return proc[field]

def thread_for_pthread(pthrd):
	"""Return the `this_thread` field of the kernel TLS belonging to pthread *pthrd*."""
	kernel_tls = tls_for_pthread(pthrd)
	return kernel_tls['_X11this_threadVPS7thread$_1']

def thread_for_proc(proc):
	"""Return the `this_thread` field of the kernel TLS that processor *proc* points at."""
	kernel_tls = tls_for_proc(proc)
	return kernel_tls['_X11this_threadVPS7thread$_1']



def find_curr_thread():
	"""
	Identify the Cforall thread running on the selected gdb thread.

	Not implemented: always returns None, so callers cannot mark a running
	thread (Threads then prints 'Could not identify current thread').
	"""
	current = None
	return current

def lookup_cluster(name = None):
	"""
	Look up a cluster by name.

	@name: str, or None / '' for the main cluster
	Return: gdb.Value of the main cluster when no name is given, otherwise of
	the first cluster whose name equals *name*; None when nothing matches
	(a message is printed) or when there are no clusters.

	A single value is returned (not a list of matches) because every caller
	uses the result directly as one cluster.
	"""
	if not is_cforall():
		return None

	clusters = all_clusters()
	if not clusters:
		return None

	if not name:
		return clusters.root

	# stop at the first cluster whose name matches
	for c in clusters:
		if c['_X4namePKc_1'].string() == name:
			return c

	print("Cannot find a cluster with the name: {}.".format(name))
	return None


def lookup_threads_by_cluster(cluster):
	"""
	Collect the threads of *cluster* by walking its circular thread list.

	Return: list of ThreadInfo in list order. Non-system threads get ids
	0, 1, 2, ... and system threads -1, -2, ...; an empty list (after a
	message is printed) when the cluster has no threads.
	"""
	found = []

	cfa_t = get_cfa_types()
	head = cluster['_X7threadsS8__dllist_S7thread$__1']['_X4headPY15__TYPE_generic__1'].cast(cfa_t.thread_ptr)

	if head == 0x0 or head.address == 0x0:
		print('There are no tasks for cluster: {}'.format(cluster))
		return found

	next_user_id = 0
	next_system_id = -1
	node = head
	while True:
		info = ThreadInfo(cluster, node)
		if info.is_system():
			info.tid, next_system_id = next_system_id, next_system_id - 1
		else:
			info.tid, next_user_id = next_user_id, next_user_id + 1
		found.append(info)

		# advance; stop after a full cycle or at a null link
		node = node['node']['next']
		if node == head or node == 0x0:
			break

	return found

def system_thread(thread):
	"""Report whether *thread* is a runtime (system) thread; not implemented, always False."""
	is_system = False
	return is_system

def adjust_stack(pc, fp, sp):
	"""
	Overwrite the program counter, frame pointer and stack pointer registers.

	The frame pointer is written through $rbp, so this is x86-64 specific.
	Callers select frame 0 first, because pc can only be set there.
	@pc, @fp, @sp: values whose str() is a valid gdb expression (e.g. addresses)
	"""
	commands = (
		'set $pc = {}'.format(pc),
		'set $rbp = {}'.format(fp),
		'set $sp = {}'.format(sp),
	)
	for command in commands:
		gdb.execute(command)

############################ COMMAND IMPLEMENTATION #########################

class Clusters(gdb.Command):
	"""Cforall: Display currently known clusters
Usage:
	info clusters                 : print out all the clusters
"""

	def __init__(self):
		# registers the gdb command 'info clusters'
		super().__init__('info clusters', gdb.COMMAND_USER)

	def print_cluster(self, cluster_name, cluster_address):
		"""Print one right-aligned table row: name, then address."""
		row = '{:>20}  {:>20}'.format(cluster_name, cluster_address)
		print(row)

	def invoke(self, arg, from_tty):
		"""gdb entry point; the command takes no arguments."""
		if not is_cforall():
			return

		if arg:
			print("info clusters does not take arguments")
			print_usage(self)
			return

		self.print_cluster('Name', 'Address')
		for cluster in all_clusters():
			name = cluster['_X4namePKc_1'].string()
			self.print_cluster(name, str(cluster))
		print("")

############
class Processors(gdb.Command):
	"""Cforall: Display currently known processors
Usage:
	info processors                 : print out all the processors
	info processors <cluster_name>  : print out all processors in a given cluster
"""

	def __init__(self):
		# registers the gdb command 'info processors'
		super(Processors, self).__init__('info processors', gdb.COMMAND_USER)

	def print_processor(self, processor, in_stats):
		"""
		Print one processor: a row with its name, status, pending-preemption
		flag and address, followed by the thread named in its TLS and the TLS
		pointer itself.
		@processor: gdb.Value pointing to a processor
		@in_stats: str status shown when the processor is not terminating
		           (invoke passes 'Active' or 'Idle')
		"""
		should_stop = processor['_X12do_terminateVb_1']
		if not should_stop:
			status = in_stats
		else:
			# terminating: report the count of the processor's `terminated` semaphore
			stop_count  = processor['_X10terminatedS9semaphore_1']['_X5counti_1']
			status_str  = 'Last Thread' if stop_count >= 0 else 'Terminating'
			status      = '{}({},{})'.format(status_str, should_stop, stop_count)

		print('{:>20}  {:>11}  {:<7}  {:<}'.format(
			processor['_X4namePKc_1'].string(),
			status,
			str(processor['_X18pending_preemptionb_1']),
			str(processor)
		))
		tls = tls_for_proc( processor )
		thrd = tls['_X11this_threadVPS7thread$_1']
		# null this_thread: print None instead of dereferencing it
		if thrd != 0x0:
			tname = '{} {}'.format(thrd['self_cor']['name'].string(), str(thrd))
		else:
			tname = None

		print('{:>20}  {}'.format('Thread', tname))
		print('{:>20}  {}'.format('TLS', tls))

	def invoke(self, arg, from_tty):
		"""
		gdb entry point.
		@arg: str, empty for all clusters or the name of one cluster
		"""
		if not is_cforall():
			return

		if not arg:
			clusters = all_clusters()
		else:
			# lookup_cluster returns None when nothing matches; wrapping None in a
			# list made the emptiness check below pass and then crashed on None.
			cluster = lookup_cluster(arg)
			clusters = [cluster] if cluster is not None else []

		if not clusters:
			print("No Cluster matching arguments found")
			return

		print('{:>20}  {:>11}  {:<7}  {}'.format('Processor', '', 'Pending', 'Object'))
		print('{:>20}  {:>11}  {:<7}  {}'.format('Name', 'Status', 'Yield', 'Address'))
		for c in clusters:
			print('Cluster {}'.format(c['_X4namePKc_1'].string()))

			active, idle = proc_list(c)
			# print the processor information
			for p in active:
				self.print_processor(p, 'Active')

			for p in idle:
				self.print_processor(p, 'Idle')

			print()

		print()

############
class Threads(gdb.Command):
	"""Cforall: Display currently known threads
Usage:
	info cfathreads                      : print Main Cluster threads, application threads only
	info cfathreads all                  : print all clusters, all threads
	info cfathreads <clusterName>        : print cluster threads, application threads only
	"""
	def __init__(self):
		# registers the gdb command 'info cfathreads'; the usage text above must
		# use the same name
		super(Threads, self).__init__('info cfathreads', gdb.COMMAND_USER)

	def print_formatted(self, marked, tid, name, state, address):
		"""Print one table row; a marked row gets a leading '*'."""
		print('{:>1}  {:>4}  {:>20}  {:>10}  {:>20}'.format('*' if marked else ' ', tid, name, state, address))

	def print_thread(self, thread, tid, marked):
		"""
		Print one thread row. The State column shows 'preempted', 'yield' or
		'poll' for the corresponding preemption reason, otherwise the thread's
		coroutine state; an unknown reason prints an error and shows 'error'.
		"""
		cfa_t = get_cfa_types()
		ys = str(thread['preempted'].cast(cfa_t.yield_state))
		if ys == '_X15__NO_PREEMPTIONKM19__Preemption_Reason_1':
			state = str(thread['state'].cast(cfa_t.thread_state))
		elif ys == '_X18__ALARM_PREEMPTIONKM19__Preemption_Reason_1':
			state = 'preempted'
		elif ys == '_X19__MANUAL_PREEMPTIONKM19__Preemption_Reason_1':
			state = 'yield'
		elif ys == '_X17__POLL_PREEMPTIONKM19__Preemption_Reason_1':
			state = 'poll'
		else:
			print("error: thread {} in undefined preemption state {}".format(thread, ys))
			state = 'error'
		self.print_formatted(marked, tid, thread['self_cor']['name'].string(), state, str(thread))

	def print_threads_by_cluster(self, cluster, print_system = False):
		"""
		Print the threads of *cluster* as a table, including system threads
		only when *print_system* is true.
		"""
		threads = lookup_threads_by_cluster(cluster)
		if not threads:
			return

		running_thread = find_curr_thread()
		if running_thread is None:
			print('Could not identify current thread')

		self.print_formatted(False, '', 'Name', 'State', 'Address')

		for t in threads:
			if not t.is_system() or print_system:
				self.print_thread(t.value, t.tid, t.value == running_thread if running_thread else False)

		print()

	def print_all_threads(self, print_system = False):
		"""Print the threads of every cluster; system threads only when *print_system*."""
		for c in all_clusters():
			self.print_threads_by_cluster(c, print_system)

	def invoke(self, arg, from_tty):
		"""
		@arg: str
		@from_tty: bool
		"""
		if not is_cforall():
			return

		if not arg:
			cluster = lookup_cluster()
			if not cluster:
				print("Could not find Main Cluster")
				return

			# only tasks and main
			self.print_threads_by_cluster(cluster, False)

		elif arg == 'all':
			# all threads, all clusters: include system threads, as the usage says
			self.print_all_threads(True)

		else:
			cluster = lookup_cluster(arg)
			if not cluster:
				print("No matching cluster")
				return

			# all tasks, specified cluster
			# NOTE(review): the usage text says "application threads only" here but
			# system threads are requested; is_system() is currently always False,
			# so the output is the same either way — confirm intent.
			self.print_threads_by_cluster(cluster, True)


############
class Thread(gdb.Command):
	"""Cforall: Switch to specified user threads
Usage:
	cfathread <id>                       : switch stack to thread id on main cluster
	cfathread 0x<address>	             : switch stack to thread on any cluster
	cfathread <id> <clusterName>         : switch stack to thread on specified cluster
	"""
	def __init__(self):
		# registers the gdb command 'cfathread'
		super(Thread, self).__init__('cfathread', gdb.COMMAND_USER)

	############################ AUXILIARY FUNCTIONS #########################

	def switchto(self, thread):
		"""Show the stack of a (not currently running) Cforall thread.

		The current sp/fp/pc are pushed onto STACK so 'prevtask'/'reset' can
		restore them. The registers are then rewritten so that frame 0 appears
		to be inside __cfactx_switch on *thread*'s saved context. x86-64 only.
		@thread: gdb.Value pointing to a thread$ structure
		"""
		try:
			switch_sym = gdb.lookup_symbol('__cfactx_switch')[0]
		except gdb.error as err:
			print('Cannot look up __cfactx_switch: {}'.format(err))
			return
		# lookup_symbol returns a (symbol, is_field_of_this) tuple, which is always
		# truthy; the symbol itself is None when it is unavailable.
		if switch_sym is None:
			print('__cfactx_switch symbol is unavailable')
			return

		cfa_t = get_cfa_types()

		state = thread['state'].cast(cfa_t.thread_state)
		try:
			if state == gdb.parse_and_eval('Halted'):
				print('Cannot switch to a terminated thread')
				return

			if state == gdb.parse_and_eval('Start'):
				print('Cannot switch to a thread not yet run')
				return
		except gdb.error as err:
			print('Cannot determine the state of thread {}: {}'.format(thread, err))
			return

		context = thread['context']

		# must be at frame 0 to set pc register
		gdb.execute('select-frame 0')
		if gdb.selected_frame().architecture().name() != 'i386:x86-64':
			print('gdb debugging only supported for i386:x86-64 for now')
			return

		# gdb seems to handle things much better if we pretend we just entered the context switch:
		# pretend the pc is __cfactx_switch and adjust the sp; the base pointer is the saved one
		xsp = context['SP'] + 40 # 40 = 5 64bit registers : %r15, %r14, %r13, %r12, %rbx WARNING: x64 specific
		xfp = context['FP']

		# keep only the address part of the printed value ("0x... <__cfactx_switch>")
		try:
			xpc = get_addr(gdb.parse_and_eval('__cfactx_switch').address)
		except gdb.error as err:
			print('Cannot find the address of __cfactx_switch: {}'.format(err))
			return

		# read the name before touching STACK or any register, so a failure here
		# leaves the debugger state unchanged
		tname = thread['self_cor']['name'].string()

		# push sp, fp, pc onto the global stack
		sp = gdb.parse_and_eval('$sp')
		fp = gdb.parse_and_eval('$fp')
		pc = gdb.parse_and_eval('$pc')
		STACK.append(StackInfo(sp = sp, fp = fp, pc = pc))

		# update registers for new task
		print('switching to thread {} ({})'.format(str(thread), tname))
		gdb.execute('set $rsp={}'.format(xsp))
		gdb.execute('set $rbp={}'.format(xfp))
		gdb.execute('set $pc={}'.format(xpc))

	@staticmethod
	def find_matching_gdb_thread_id(task):
		"""
		Return the gdb thread number (as str) whose 'info thread' line contains
		'this=<task>', or None if no line matches.

		NOTE(review): unused in this file; apparently carried over from uC++
		(the former body referenced uBaseTask::Running and undefined names
		`task`/`task_state`, so it could never run).
		"""
		thread_id_pattern = re.compile(r'^\*?\s+(\d+)\s+Thread')
		for thread_str in gdb.execute('info thread', to_string=True).splitlines():
			if thread_str.find('this={}'.format(task)) != -1:
				match = thread_id_pattern.match(thread_str)
				if match:
					return match.group(1)
		return None

	def switchto_id(self, tid, cluster):
		"""
		Switch to the thread with display id *tid* on *cluster*.
		@cluster: cluster object
		@tid: int
		"""
		threads = lookup_threads_by_cluster( cluster )

		for t in threads:
			if t.tid == tid:
				self.switchto(t.value)
				return

		print("Could not find thread by id '{}'".format(tid))

	def invoke(self, arg, from_tty):
		"""
		@arg: str, '<id> [clusterName]' or '0x<address>'
		@from_tty: bool
		"""
		if not is_cforall():
			return

		argv = parse(arg)
		if not argv:
			print_usage(self)
			return

		if argv[0].isdigit():
			cname = " ".join(argv[1:]) if len(argv) > 1 else None
			cluster = lookup_cluster(cname)
			if not cluster:
				print("Could not find cluster '{}'".format(cname if cname else "Main Cluster"))
				return

			try:
				tid = int(argv[0])
			except ValueError:
				# isdigit() also accepts characters int() rejects (e.g. superscripts)
				print("'{}' not a valid thread id".format(argv[0]))
				print_usage(self)
				return

			# by id, on the given cluster (main cluster by default)
			self.switchto_id(tid, cluster)

		elif argv[0].startswith(('0x', '0X')):
			# by address, any cluster: switchto needs a thread pointer, not the text
			try:
				thread = gdb.parse_and_eval(argv[0]).cast(get_cfa_types().thread_ptr)
			except gdb.error as err:
				print("'{}' is not a valid thread address: {}".format(argv[0], err))
				return
			self.switchto(thread)

		else:
			print("'{}' is neither a thread id nor an address".format(argv[0]))
			print_usage(self)

############
class PrevThread(gdb.Command):
	"""Switch back to previous task on the stack"""
	usage_msg = 'prevtask'

	def __init__(self):
		# registers the gdb command 'prevtask'
		super().__init__('prevtask', gdb.COMMAND_USER)

	def invoke(self, arg, from_tty):
		"""
		gdb entry point: restore the registers saved by the most recent thread
		switch (the top of STACK) and drop that entry. *arg* is ignored.
		"""
		if not STACK:
			print('empty stack')
			return

		# must be at frame 0 to set pc register
		gdb.execute('select-frame 0')

		saved = STACK.pop()
		adjust_stack(get_addr(saved.pc), saved.fp, saved.sp)

		# must be at C++ frame to access C++ vars (selects frame 1)
		gdb.execute('frame 1')

class ResetOriginFrame(gdb.Command):
	"""Reset to the origin frame prior to continue execution again"""
	usage_msg = 'resetOriginFrame'
	def __init__(self):
		# registers the gdb command 'reset'
		super(ResetOriginFrame, self).__init__('reset', gdb.COMMAND_USER)

	def invoke(self, arg, from_tty):
		"""
		gdb entry point: restore the registers saved before the first thread
		switch (the bottom of STACK) and discard every saved entry. Does
		nothing when STACK is empty. *arg* is ignored.
		"""
		if not STACK:
			# nothing to restore; no message needed
			return

		# must be at frame 0 to set pc register (as prevtask and cfathread do)
		gdb.execute('select-frame 0')

		stack_info = STACK[0]
		STACK.clear()
		adjust_stack(get_addr(stack_info.pc), stack_info.fp, stack_info.sp)

		# must be at C++ frame to access C++ vars (selects frame 1)
		gdb.execute('frame 1')

# Instantiating a gdb.Command subclass registers its command with gdb.
for _command_class in (Clusters, Processors, ResetOriginFrame, PrevThread, Threads, Thread):
	_command_class()
del _command_class

# Local Variables: #
# mode: Python #
# End: #
