#
# Copyright (C) Lynn Tran, Jiachen Zhang 2018
#
# utils-gdb.py --
#
# Author           : Lynn Tran
# Created On       : Mon Oct 1 22:06:09 2018
# Last Modified By : Peter A. Buhr
# Last Modified On : Sat Jan 19 14:16:10 2019
# Update Count     : 11
#

"""
To run this extension, this file's name must match the name of one of the loaded libraries.
Additionally, the file must reside in a directory that is on gdb's auto-load safe path.
"""
import collections
import gdb
import re

# Configure gdb's handling of SIGALRM and SIGUSR1: do not stop the program
# (nostop), do not announce the signal (noprint), and still deliver it to the
# program (pass).
# NOTE(review): presumably the runtime uses these signals internally - confirm.
gdb.execute('handle SIGALRM nostop noprint pass')
gdb.execute('handle SIGUSR1 nostop noprint pass')

# Bundle of gdb type handles built by get_cfa_types(): pointer types for
# cluster, processor, thread and int, plus the thread-state enum type.
CfaTypes = collections.namedtuple('CfaTypes', 'cluster_ptr processor_ptr thread_ptr int_ptr thread_state')

class ThreadInfo:
    """
    Record describing one thread found while walking a cluster's thread list.

    tid:     display id; defaults to 0 and is assigned by the enumerating code
    cluster: the cluster value the thread was found on
    value:   the thread value itself
    """
    # Class-level defaults; cluster and value are replaced per instance.
    tid, cluster, value = 0, None, None

    def __init__(self, cluster, value):
        self.cluster, self.value = cluster, value

    def is_system(self):
        """
        Return: bool, whether this is a system thread (currently always False)
        """
        return False

# A named tuple holding one saved register triple: stack pointer (sp), frame
# pointer (fp) and program counter (pc).
StackInfo = collections.namedtuple('StackInfo', 'sp fp pc')

# Global list used as a stack of StackInfo entries: the registers in effect
# before each switch to another task are appended here so they can be
# restored later.
STACK = []

# Names treated as system tasks.
# NOTE(review): "uProcessorTask" appears twice, and the names look like uC++
# task names rather than Cforall ones - confirm whether this list is still needed.
SysTask_Name = ["uLocalDebuggerReader", "uLocalDebugger", "uProcessorTask", "uBootTask", "uSystemTask",
"uProcessorTask", "uPthread", "uProfiler"]

# Message printed when a command is not available for the debugged language.
not_supported_error_msg = "Not a supported command for this language"

def is_cforall():
    """
    Tell the commands whether they apply to the program being debugged.
    Return: bool, currently a stub that unconditionally answers True
    """
    return True

def get_cfa_types():
    """
    Look up the gdb types the commands need and bundle them together.
    Return: CfaTypes with pointer types for 'struct cluster',
        'struct processor', 'struct $thread' and 'int', plus the
        'enum coroutine_state' type itself
    """
    def pointer_to(type_name):
        # gdb.Type for "pointer to <type_name>"
        return gdb.lookup_type(type_name).pointer()

    return CfaTypes(
        cluster_ptr=pointer_to('struct cluster'),
        processor_ptr=pointer_to('struct processor'),
        thread_ptr=pointer_to('struct $thread'),
        int_ptr=pointer_to('int'),
        thread_state=gdb.lookup_type('enum coroutine_state'),
    )

def get_addr(addr):
    """
    Extract the bare address from gdb's textual form of an address.
    NOTE: string slicing is a stop-gap; a cleaner way to obtain the address
    probably exists.
    @addr: anything whose str() looks like "0xfffff <symbol info>"
    Return: str, the text before the first '<' with surrounding whitespace
        removed, or str(addr) unchanged when there is no '<'
    """
    text = str(addr)
    head, marker, _ = text.partition('<')
    if not marker:
        # No "<...>" annotation present: hand the text back untouched
        return text
    return head.strip()

def print_usage(obj):
    """
    Print the docstring of obj; commands keep their usage text there.
    @obj: any object
    """
    usage_text = obj.__doc__
    print(usage_text)

def parse(args):
    """
    Split a command-line string into a list of arguments, like argv.
    Arguments are separated by any run of whitespace, so repeated, leading or
    trailing blanks never produce empty-string arguments.
    @args: str of arguments
    Return:
        [] if args is empty or contains only whitespace
        list of str otherwise
    """
    if not args:
        return []
    # str.split() with no separator collapses whitespace runs and drops
    # leading/trailing whitespace, unlike split(' ').
    return args.split()

def get_cluster_root():
    """
    Evaluate the symbol '_X11mainClusterPS7cluster_1' in the debugged program
    (presumably the mangled name of the main cluster pointer - TODO confirm).
    Prints a message when the resulting value's address is 0.
    Return: gdb.Value of that symbol
    """
    root = gdb.parse_and_eval('_X11mainClusterPS7cluster_1')
    no_clusters = root.address == 0x0
    if no_clusters:
        print('No clusters, program terminated')
    return root

def find_curr_thread():
    """
    Identify the thread that is currently running.
    Not implemented: an earlier approach that parsed 'this=' out of the first
    line of gdb's 'bt' output is disabled.
    Return: None, always
    """
    return None

def lookup_cluster(name = None):
    """
    Look up a cluster by its name
    @name: str, or None/empty to get the root cluster
    Return:
        None if the program is not Cforall, there are no clusters, or no
            cluster has the given name (a message is printed in that case)
        gdb.Value of the root cluster if no name is given
        gdb.Value (the .address of the matching list node) otherwise
    """
    if not is_cforall():
        return None

    root = get_cluster_root()
    if root.address == 0x0:
        return None

    if not name:
        return root

    # Walk the circular cluster list until the name matches, the walk comes
    # back to the root, or a null link is met.
    found = None
    node = root
    finished = False
    while not finished:
        if node['_X4namePKc_1'].string() == name:
            # NOTE(review): this returns node.address whereas the unnamed
            # path returns the node value itself - confirm this is intended.
            found = node.address
            finished = True
        else:
            node = node['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']
            finished = node == root or node == 0x0

    if not found:
        print("Cannot find a cluster with the name: {}.".format(name))
        return None

    return found

def lookup_threads_by_cluster(cluster):
    """
    Collect the threads on a cluster by walking its circular thread list.
    @cluster: gdb.Value of a cluster
    Return: list of ThreadInfo; empty (after printing a message) when the
        list head is null. Non-system threads are numbered 0, 1, 2, ... and
        system threads -1, -2, ... in list order.
    """
    cfa_t = get_cfa_types()
    thread_list = cluster['_X7threadsS8__dllist_S7$thread__1']
    head = thread_list['_X4headPY15__TYPE_generic__1'].cast(cfa_t.thread_ptr)

    found = []
    if head == 0x0 or head.address == 0x0:
        print('There are no tasks for cluster: {}'.format(cluster))
        return found

    next_user_id = 0
    next_system_id = -1
    node = head
    while True:
        info = ThreadInfo(cluster, node)
        if info.is_system():
            info.tid = next_system_id
            next_system_id -= 1
        else:
            info.tid = next_user_id
            next_user_id += 1
        found.append(info)

        # Advance; stop after a full lap or on a null link
        node = node['node']['next']
        if node == head or node == 0x0:
            break

    return found

def system_thread(thread):
    """
    Tell whether a thread is a system thread.
    @thread: ignored
    Return: bool, currently a stub that always answers False
    """
    return False

def adjust_stack(pc, fp, sp):
    """
    Set gdb's $pc, $rbp and $sp registers, in that order.
    @pc: value for $pc
    @fp: value for $rbp
    @sp: value for $sp
    """
    for register, value in (('pc', pc), ('rbp', fp), ('sp', sp)):
        gdb.execute('set ${} = {}'.format(register, value))

############################ COMMAND IMPLEMENTATION #########################

class Clusters(gdb.Command):
    """Cforall: Display currently known clusters
Usage:
    info clusters                 : print out all the clusters
"""

    def __init__(self):
        super(Clusters, self).__init__('info clusters', gdb.COMMAND_USER)

    def print_cluster(self, cluster_name, cluster_address):
        """
        Print one right-aligned two-column row.
        @cluster_name: str
        @cluster_address: str
        """
        print('{:>20}  {:>20}'.format(cluster_name, cluster_address))

    #entry point from gdb
    def invoke(self, arg, from_tty):
        """
        Print the name and address of every cluster in the cluster list.
        @arg: str, must be empty
        @from_tty: bool
        """
        if not is_cforall():
            return

        if arg:
            print("info clusters does not take arguments")
            print_usage(self)
            return

        cluster_root = get_cluster_root()
        if cluster_root.address == 0x0:
            return

        curr = cluster_root
        self.print_cluster('Name', 'Address')

        while True:
            self.print_cluster(curr['_X4namePKc_1'].string(), str(curr))
            curr = curr['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']
            # Stop after a full lap, or on a null link so the next iteration
            # does not dereference a null pointer (same guard as the other
            # list walks in this file).
            if curr == cluster_root or curr == 0x0:
                break

        print("")

############
class Processors(gdb.Command):
    """Cforall: Display currently known processors
Usage:
    info processors                 : print out all the processors in the Main Cluster
    info processors <cluster_name>  : print out all processors in a given cluster
"""

    def __init__(self):
        super(Processors, self).__init__('info processors', gdb.COMMAND_USER)

    def print_processor(self, name, status, pending, address):
        """
        Print one right-aligned four-column row.
        @name: str
        @status: str
        @pending: str
        @address: str
        """
        print('{:>20}  {:>11}  {:>13}  {:>20}'.format(name, status, pending, address))

    def iterate_procs(self, root, active):
        """
        Print one row per processor in the circular list starting at root.
        @root: gdb.Value, head of a processor list; nothing is printed if null
        @active: bool, True when the list is the cluster's active list;
            selects the 'Active' or 'Idle' label for processors that are not
            terminating
        """
        if root == 0x0:
            return

        curr = root

        while True:
            should_stop = curr['_X12do_terminateVb_1']
            stop_count  = curr['_X10terminatedS9semaphore_1']['_X5counti_1']
            if not should_stop:
                status = 'Active' if active else 'Idle'
            else:
                status_str  = 'Last Thread' if stop_count >= 0 else 'Terminating'
                status      = '{}({},{})'.format(status_str, should_stop, stop_count)

            self.print_processor(curr['_X4namePKc_1'].string(),
                    status, str(curr['_X18pending_preemptionb_1']), str(curr)
                )

            curr = curr['_X4nodeS28__processor____dbg_node_proc_1']['_X4nextPS9processor_1']

            # Stop after a full lap or on a null link
            if curr == root or curr == 0x0:
                break

    #entry point from gdb
    def invoke(self, arg, from_tty):
        """
        Print the active and idle processors of a cluster.
        @arg: str, cluster name; empty selects the root cluster
        @from_tty: bool
        """
        if not is_cforall():
            return

        cluster = lookup_cluster(arg if arg else None)

        if not cluster:
            print("No Cluster matching arguments found")
            return

        cfa_t = get_cfa_types()
        # Cast once and reuse for the header and both list heads
        cluster_ptr = cluster.cast(cfa_t.cluster_ptr)
        print('Cluster: "{}"({})'.format(cluster['_X4namePKc_1'].string(), cluster_ptr))

        active_root = cluster_ptr \
                ['_X5procsS8__dllist_S9processor__1'] \
                ['_X4headPY15__TYPE_generic__1'] \
                .cast(cfa_t.processor_ptr)

        idle_root = cluster_ptr \
                ['_X5idlesS8__dllist_S9processor__1'] \
                ['_X4headPY15__TYPE_generic__1'] \
                .cast(cfa_t.processor_ptr)

        if idle_root != 0x0 or active_root != 0x0:
            self.print_processor('Name', 'Status', 'Pending Yield', 'Address')
            self.iterate_procs(active_root, True)
            self.iterate_procs(idle_root, False)
        else:
            print("No processors on cluster")

        print()

############
class Threads(gdb.Command):
    """Cforall: Display currently known threads
Usage:
    info cfathreads                           : print Main Cluster threads, application threads only
    info cfathreads all                       : print all clusters, all threads (not implemented yet)
    info cfathreads <clusterName>             : print all threads on the named cluster
    """
    def __init__(self):
        # The first argument is the command name as typed at the gdb prompt
        super(Threads, self).__init__('info cfathreads', gdb.COMMAND_USER)

    def print_formatted(self, marked, tid, name, state, address):
        """
        Print one right-aligned row; the first column is '*' when marked.
        @marked: bool
        @tid: int or str
        @name: str
        @state: str
        @address: str
        """
        print('{:>1}  {:>4}  {:>20}  {:>10}  {:>20}'.format('*' if marked else ' ', tid, name, state, address))

    def print_thread(self, thread, tid, marked):
        """
        Print the row for one thread: id, coroutine name, state and address.
        @thread: gdb.Value of a thread
        @tid: int
        @marked: bool
        """
        cfa_t = get_cfa_types()
        self.print_formatted(marked, tid, thread['self_cor']['name'].string(), str(thread['state'].cast(cfa_t.thread_state)), str(thread))

    def print_formatted_cluster(self, str_format, cluster_name, cluster_addr):
        """
        Print a cluster name and address using the given format string.
        @str_format: str with two replacement fields
        @cluster_name: str
        @cluster_addr: str
        """
        print(str_format.format(cluster_name, cluster_addr))

    def print_threads_by_cluster(self, cluster, print_system = False):
        """
        Print a table of the threads on a cluster.
        @cluster: gdb.Value of a cluster
        @print_system: bool, also list threads whose is_system() is True
        """
        threads = lookup_threads_by_cluster(cluster)
        if not threads:
            return

        running_thread = find_curr_thread()
        if running_thread is None:
            print('Could not identify current thread')

        self.print_formatted(False, '', 'Name', 'State', 'Address')

        for t in threads:
            if not t.is_system() or print_system:
                # Mark the row only when the running thread is known
                self.print_thread(t.value, t.tid, t.value == running_thread if running_thread else False)

        print()

    def print_all_threads(self):
        """Placeholder for 'info cfathreads all'; only prints a notice."""
        print("Not implemented")

    def invoke(self, arg, from_tty):
        """
        Dispatch on the argument: none, 'all', or a cluster name.
        @arg: str
        @from_tty: bool
        """
        if not is_cforall():
            return

        if not arg:
            cluster = lookup_cluster()
            if not cluster:
                print("Could not find Main Cluster")
                return

            # application threads only
            self.print_threads_by_cluster(cluster, False)

        elif arg == 'all':
            # all threads, all clusters
            self.print_all_threads()

        else:
            cluster = lookup_cluster(arg)
            if not cluster:
                print("Could not find cluster '{}'".format(arg))
                return

            # all threads, specified cluster
            self.print_threads_by_cluster(cluster, True)


############
class Thread(gdb.Command):
    """Cforall: Display threads, or switch gdb's registers to a thread's saved stack
Usage:
    cfathread                            : print Main Cluster threads, application threads only
    cfathread <clusterName>              : print all threads on the named cluster
    cfathread all                        : print all clusters, all threads
    cfathread <id>                       : switch stack to thread id on the Main Cluster
    cfathread 0x<address>                : switch stack to the thread at the given address
    cfathread <id> <clusterName>         : switch stack to thread id on the named cluster
    cfathread help                       : print this message
"""

    # Layout of the cluster header and rows printed by print_all_threads;
    # same column widths as Clusters.print_cluster.
    cluster_str_format = '{:>20}  {:>20}'

    def __init__(self):
        # The first argument is the command name as typed at the gdb prompt.
        # super() must name this class: a Thread instance is not a Threads
        # instance, so super(Threads, self) raises TypeError.
        super(Thread, self).__init__('cfathread', gdb.COMMAND_USER)

    def print_usage(self):
        """Print the usage text, which is this class's docstring."""
        # Inside a method the bare name resolves to the module-level
        # print_usage(obj), which prints obj.__doc__.
        print_usage(self)

    ############################ AUXILIARY FUNCTIONS #########################

    def print_formatted(self, marked, tid, name, state, address):
        """
        Print one right-aligned row; the first column is '*' when marked.
        @marked: bool
        @tid: int or str
        @name: str
        @state: str
        @address: str
        """
        print('{:>1}  {:>4}  {:>20}  {:>10}  {:>20}'.format('*' if marked else ' ', tid, name, state, address))

    def print_thread(self, thread, tid, marked):
        """
        Print the row for one thread: id, coroutine name, state and address.
        @thread: gdb.Value of a thread
        @tid: int
        @marked: bool
        """
        cfa_t = get_cfa_types()
        self.print_formatted(marked, tid, thread['self_cor']['name'].string(), str(thread['state'].cast(cfa_t.thread_state)), str(thread))

    def print_formatted_cluster(self, str_format, cluster_name, cluster_addr):
        """
        Print a cluster name and address using the given format string.
        @str_format: str with two replacement fields
        @cluster_name: str
        @cluster_addr: str
        """
        print(str_format.format(cluster_name, cluster_addr))

    def print_tasks_by_cluster_all(self, cluster_address):
        """
        Print every thread, system threads included, on one cluster.
        @cluster_address: gdb.Value holding the address of a cluster
        """
        cluster = cluster_address.cast(get_cfa_types().cluster_ptr)
        self.print_threads_by_cluster(cluster, True)

    def print_tasks_by_cluster_address_all(self, cluster_address):
        """
        Print every thread on the cluster at a textual hexadecimal address.
        Prints the usage text instead when the text is not valid hexadecimal.
        @cluster_address: str
        """
        try:
            hex_addr = int(cluster_address, 16)
        except (TypeError, ValueError):
            self.print_usage()
            return

        self.print_tasks_by_cluster_all(gdb.Value(hex_addr))

    def print_threads_by_cluster(self, cluster, print_system = False):
        """
        Print a table of the threads on a cluster.
        @cluster: gdb.Value of a cluster
        @print_system: bool, also list threads whose is_system() is True
        """
        threads = lookup_threads_by_cluster(cluster)
        if not threads:
            return

        running_thread = find_curr_thread()
        if running_thread is None:
            print('Could not identify current thread')

        self.print_formatted(False, '', 'Name', 'State', 'Address')

        for t in threads:
            if not t.is_system() or print_system:
                # Mark the row only when the running thread is known
                self.print_thread(t.value, t.tid, t.value == running_thread if running_thread else False)

        print()

    ############################ COMMAND FUNCTIONS #########################

    def print_all_threads(self):
        """
        Walk the circular cluster list and, for each cluster, print its name
        and address followed by all of its threads.
        """
        cluster_root = get_cluster_root()
        if cluster_root.address == 0x0:
            return

        curr = cluster_root
        self.print_formatted_cluster(self.cluster_str_format, 'Cluster Name', 'Address')

        while True:
            self.print_formatted_cluster(self.cluster_str_format, curr['_X4namePKc_1'].string(), str(curr))
            self.print_threads_by_cluster(curr, True)

            curr = curr['_X4nodeS26__cluster____dbg_node_cltr_1']['_X4nextPS7cluster_1']
            # Stop after a full lap or on a null link
            if curr == cluster_root or curr == 0x0:
                break

    def switchto(self, thread):
        """
        Point gdb's registers at a thread's saved context so its stack can be
        inspected. The current $sp, $fp and $pc are pushed on the global
        STACK first so they can be restored later.
        Refuses (with a message) threads whose state is Halted or Start.
        @thread: gdb.Value of a thread
        """
        # Resolve the context-switch routine up front; its address is needed
        # for the new pc. gdb.lookup_symbol() returns a tuple, so testing its
        # truthiness can never detect a missing symbol.
        try:
            switch_routine = gdb.parse_and_eval('__cfactx_switch')
        except gdb.error:
            print('__cfactx_switch symbol is unavailable')
            return

        cfa_t = get_cfa_types()

        state = thread['state'].cast(cfa_t.thread_state)
        try:
            if state == gdb.parse_and_eval('Halted'):
                print('Cannot switch to a terminated thread')
                return

            if state == gdb.parse_and_eval('Start'):
                print('Cannot switch to a thread not yet run')
                return
        except gdb.error as err:
            print('Could not evaluate thread state: {}'.format(err))
            return

        context = thread['context']

        # NOTE(review): the offsets 48 and 28 are presumably tied to the
        # stack layout and code of __cfactx_switch - confirm for the target.
        xsp = context['SP'] + 48
        xfp = context['FP']

        # convert to string so the "<symbol+offset>" suffix can be stripped
        xpc = get_addr(switch_routine.address + 28)

        # must be at frame 0 to set pc register
        gdb.execute('select-frame 0')

        # push sp, fp, pc into the global stack
        sp = gdb.parse_and_eval('$sp')
        fp = gdb.parse_and_eval('$fp')
        pc = gdb.parse_and_eval('$pc')
        STACK.append(StackInfo(sp = sp, fp = fp, pc = pc))

        # update registers for the new thread
        print('switching to {}'.format(thread))
        gdb.execute('set $rsp={}'.format(xsp))
        gdb.execute('set $rbp={}'.format(xfp))
        gdb.execute('set $pc={}'.format(xpc))

    @staticmethod
    def find_matching_gdb_thread_id(task = None):
        """
        Scan the output of gdb's 'info thread' for a line containing
        'this=<task>' and extract the gdb thread number from it.
        NOTE(review): the 'this=' marker looks inherited from the uC++
        version of this script - confirm it ever appears for Cforall.
        @task: value whose str() identifies the task; None matches nothing
        Return: str of the gdb thread id, or None if no line matches
        """
        if task is None:
            return None

        marker = 'this={}'.format(task)
        thread_id_pattern = re.compile(r'^\*?\s+(\d+)\s+Thread')
        for thread_str in gdb.execute('info thread', to_string=True).splitlines():
            if marker in thread_str:
                found = thread_id_pattern.match(thread_str)
                if found:
                    return found.group(1)
        return None

    def switchto_id(self, tid, cluster):
        """
        Switch to the thread with the given id on a cluster, or print a
        message when no thread has that id.
        @cluster: cluster object
        @tid: int
        """
        for t in lookup_threads_by_cluster(cluster):
            if t.tid == tid:
                self.switchto(t.value)
                return

        print("Could not find thread by id '{}'".format(tid))

    def invoke(self, arg, from_tty):
        """
        Dispatch on the arguments; see the class docstring for the forms.
        @arg: str
        @from_tty: bool
        """
        if not is_cforall():
            return

        argv = parse(arg)
        if len(argv) == 0:
            # Main Cluster, application threads only
            cluster = lookup_cluster()
            if not cluster:
                print("Could not find Main Cluster")
                return

            self.print_threads_by_cluster(cluster, False)

        elif len(argv) == 1:
            if argv[0] == 'help':
                self.print_usage()

            elif argv[0].isdigit():
                # switch by id on the Main Cluster
                cluster = lookup_cluster()
                if not cluster:
                    print("Could not find Main Cluster")
                    return

                try:
                    tid = int(argv[0])
                except ValueError:
                    print("'{}' not a valid thread id".format(argv[0]))
                    self.print_usage()
                    return

                self.switchto_id(tid, cluster)

            elif argv[0].startswith(('0x', '0X')):
                # switch by address: build a thread pointer from the number,
                # since switchto() needs a gdb.Value rather than a string
                try:
                    address = int(argv[0], 16)
                except ValueError:
                    print("'{}' not a valid thread address".format(argv[0]))
                    self.print_usage()
                    return

                self.switchto(gdb.Value(address).cast(get_cfa_types().thread_ptr))

            elif argv[0] == 'all':
                self.print_all_threads() # all threads, all clusters

            else:
                # all threads on the named cluster; lookup_cluster prints a
                # message itself when the name is unknown
                cluster = lookup_cluster(argv[0])
                if not cluster:
                    return

                self.print_threads_by_cluster(cluster, True)

        elif len(argv) == 2:
            # switch by id on the named cluster
            try:
                tid = int(argv[0])
            except ValueError:
                print("'{}' not a valid thread id".format(argv[0]))
                self.print_usage()
                return

            cluster = lookup_cluster(argv[1])
            if not cluster:
                return

            self.switchto_id(tid, cluster)

        else:
            print('Invalid arguments')
            self.print_usage()

############
class PrevThread(gdb.Command):
    """Switch back to previous task on the stack"""
    usage_msg = 'prevtask'

    def __init__(self):
        super(PrevThread, self).__init__('prevtask', gdb.COMMAND_USER)

    def invoke(self, arg, from_tty):
        """
        Restore the most recently saved registers from the global STACK, or
        print a message when nothing has been saved.
        @arg: str
        @from_tty: bool
        """
        if len(STACK) == 0:
            print('empty stack')
            return

        # must be at frame 0 to set pc register
        gdb.execute('select-frame 0')

        # take the newest saved registers off the global stack and apply them
        saved = STACK.pop()
        adjust_stack(get_addr(saved.pc), saved.fp, saved.sp)

        # must be at C++ frame to access C++ vars
        gdb.execute('frame 1')

class ResetOriginFrame(gdb.Command):
    """Reset to the origin frame prior to continue execution again"""
    # NOTE(review): the command is registered as 'reset', not 'resetOriginFrame'
    usage_msg = 'resetOriginFrame'
    def __init__(self):
        super(ResetOriginFrame, self).__init__('reset', gdb.COMMAND_USER)

    def invoke(self, arg, from_tty):
        """
        Restore the oldest saved registers (those in effect before the first
        switch) and discard every other saved entry. Does nothing, silently,
        when nothing has been saved.
        @arg: str
        @from_tty: bool
        """
        if not STACK:
            return

        # keep the oldest entry, drop everything
        stack_info = STACK[0]
        STACK.clear()

        # must be at frame 0 to set pc register (as the other register-setting
        # commands in this file do before writing $pc)
        gdb.execute('select-frame 0')

        adjust_stack(get_addr(stack_info.pc), stack_info.fp, stack_info.sp)

        # must be at C++ frame to access C++ vars
        gdb.execute('frame 1')

# Instantiate the command classes; each constructor passes its command name
# to gdb.Command.__init__.
# NOTE(review): the Thread class ('cfathread') is not instantiated here.
Clusters()
Processors()
ResetOriginFrame()
PrevThread()
Threads()

# Local Variables: #
# mode: Python #
# End: #
