import argparse
import os
from pprint import pprint

from antlr4 import *

from parser.CMODLexer import CMODLexer
from parser.CMODParser import CMODParser


def parse_args(argv=None):
    """Parse the command-line arguments of the tool.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what ``main()`` relies on.

    Returns:
        argparse.Namespace with two attributes:
        ``input_file`` (required positional) and ``project_root``
        (``-r`` / ``--project_root``; defaults to ``"."`` so that the later
        ``os.chdir(args.project_root)`` never receives ``None``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_file", help="file to parse")
    parser.add_argument(
        "-r",
        "--project_root",
        default=".",
        help="the root folder of the project, used to generate module names. All modules, including imports, should be within this folder",
    )
    return parser.parse_args(argv)


def main():
    """Script entry point: read the command line, then process the input module."""
    args = parse_args()
    parse_top_level(args)


def parse_top_level(args):
    """Parse a CMOD module plus everything it imports and print the generated output.

    The work is done in four passes:

    1. Parse ``args.input_file`` and, recursively, every module it imports
       (builds ``module_asts`` and ``module_imports``).
    2. Extract the top-level symbols of every parsed module (``module_data``).
    3. Combine each module's own symbols with the exported symbols of the
       modules it imports into per-module symbol tables (``module_input``).
    4. For every module, emit each symbol after the declarations/definitions
       it needs, renaming identifiers to
       ``<module path with '/' replaced by '$'>$$<name>`` (``module_output``).

    The intermediate tables are dumped with ``pprint`` and the generated text
    of every module is printed to stdout.

    Side effect: the process working directory is changed to the project root
    (it makes it easier to navigate the imports).

    Args:
        args: Object with ``input_file`` and ``project_root`` attributes (the
            namespace returned by ``parse_args``). A missing/empty
            ``project_root`` is treated as the current directory.

    Raises:
        RuntimeError: if ``input_file`` does not end in ``.cmod``, if a module
            path contains ``'.'`` (e.g. it escapes the project root), or if the
            parser reports syntax errors for a module.
    """
    # Fall back to the current directory when no root was given; os.chdir(None)
    # would otherwise raise TypeError.
    project_root = args.project_root or os.curdir

    # Make input_file relative to project_root and strip the suffix. The suffix
    # is validated before chdir so that a rejected input has no side effects.
    suffix = '.cmod'
    input_file = os.path.relpath(args.input_file, project_root)
    if not input_file.endswith(suffix):
        raise RuntimeError("input_file must use .cmod suffix")
    input_file = input_file[:-len(suffix)]
    os.chdir(project_root)

    # Type notes to help you understand this code
    # (it helps to have this open in another tab while reading the code)
    #
    # module_asts: dict[file_path: str, (text, tokens, tree)]
    # module_imports: dict[file_path: str, list[imported_file_path: str]]
    #     text                              # str
    #     tokens[0].text                    # start, stop, line, column
    #     a, b = tree.getSourceInterval()   # getChildren, getChildCount, getChild
    #
    # module_data: dict[file_path: str, (types, variables)]
    #     types, variables: dict[name: str,
    #         (idx: int, used_names_decl, used_names_defn, is_exported: bool, AST)]
    #     used_names_decl, used_names_defn:
    #         list[(name: str, idx: int, is_variable: bool, needs_defn: bool)]
    #
    # module_input: dict[file_path: str, (type_symbol_table, var_symbol_table)]
    #     type_symbol_table, var_symbol_table: dict[name: str, file_path: str]
    #
    # Tuples of the form (types, variables) are frequently indexed with an
    # ``is_variable`` bool: False -> index 0 (types), True -> index 1 (variables).

    # ---- Pass 1: parse the input file and any recursively imported modules ----
    # (output: module_asts, module_imports)
    module_asts = {}  # dict[file_path: str, (text, tokens, tree)]
    module_imports = {}  # dict[file_path: str, list[imported_file_path: str]]
    modules_to_process = [input_file]  # list[file_path: str]
    while modules_to_process:
        file_path = modules_to_process.pop()
        if file_path in module_asts:
            continue  # already parsed (shared or cyclic import)
        if "." in file_path:
            raise RuntimeError(f"{file_path=} contains '.', which is disallowed. Make sure {args.project_root=} contains all modules")
        input_stream = FileStream(file_path + ".cmod")
        text = str(input_stream)
        lexer = CMODLexer(input_stream)
        # list[token: {text, start, stop, line, column}]
        # stop is inclusive, just like getSourceInterval
        tokens = lexer.getAllTokens()
        lexer.reset()  # rewind so the parser can consume the same input again
        parser = CMODParser(CommonTokenStream(lexer))
        tree = parser.compilationUnit()  # {getSourceInterval, getChildren, getChildCount, getChild}
        if parser.getNumberOfSyntaxErrors() > 0:
            raise RuntimeError(f"syntax errors for file {file_path}")
        module_asts[file_path] = text, tokens, tree

        import_names = []  # list[import_name: str]
        for importDeclaration in tree.getChild(0).getChildren():
            if not isinstance(importDeclaration, CMODParser.ImportDeclarationContext):
                continue
            # The import name is the source text spanning child 1 through the
            # second-to-last child of the declaration.
            child_count = importDeclaration.getChildCount()
            token_start, _ = importDeclaration.getChild(1).getSourceInterval()
            _, token_end = importDeclaration.getChild(child_count - 2).getSourceInterval()
            import_name = text[tokens[token_start].start:tokens[token_end].stop + 1]
            if import_name.startswith('"'):
                import_name = import_name[1:-1]  # drop the surrounding quotes
            import_names.append(import_name)
        # Imports are resolved relative to the directory of the importing module.
        file_path_dir = os.path.split(file_path)[0]
        import_file_paths = [os.path.normpath('./' + file_path_dir + '/' + x) for x in import_names]
        module_imports[file_path] = import_file_paths
        modules_to_process.extend(import_file_paths)

    print("---- MODULE IMPORTS: ----")
    pprint(module_imports)

    # ---- Pass 2: analyze each module, extracting top-level symbols ----
    # (output: module_data)
    def name_and_index(tokens, node):
        """Return (token text, token index) of the first token covered by node."""
        idx, _ = node.getSourceInterval()
        return tokens[idx].text, idx

    def statement_uses(tokens, compound):
        """Return the used-name tuples for every statement of a compound statement."""
        uses = []
        for ls in compound.limitedStatement():
            u_name, u_idx = name_and_index(tokens, ls.Identifier())
            child_count = ls.getChildCount()
            # is_variable <=> 2 children, needs_defn <=> 3 children
            # NOTE(review): the meaning of these child counts depends on the
            # CMOD grammar, which is not visible here.
            uses.append((u_name, u_idx, child_count == 2, child_count == 3))
        return uses

    module_data = {}  # dict[file_path: str, (types, variables)]
    for file_path, (_, tokens, tree) in module_asts.items():
        # dict[name: str, (idx: int, used_names_decl, used_names_defn, is_exported: bool, AST)]
        types, variables = {}, {}
        module_data[file_path] = types, variables
        translation_unit = tree.translationUnit()
        if translation_unit is None:
            continue  # module without any external declarations
        for externalDeclaration in translation_unit.externalDeclaration():
            lfd = externalDeclaration.limitedFunctionDefinition()
            lstruct = externalDeclaration.limitedStruct()
            lg = externalDeclaration.limitedGlobal()
            first_token = externalDeclaration.getChild(0).getSourceInterval()[0]
            is_exported = tokens[first_token].text == 'export'
            if lfd is not None:
                # list[(name: str, idx: int, is_variable: bool, needs_defn: bool)]
                used_names_decl = []
                return_struct = lfd.limitedTypeSpecifier().Identifier()
                if return_struct is not None:
                    u_name, u_idx = name_and_index(tokens, return_struct)
                    # needs_defn when the declarator has a single child
                    # (presumably a non-pointer declarator; confirm against the grammar)
                    used_names_decl.append((u_name, u_idx, False, lfd.limitedDeclarator().getChildCount() == 1))
                lpl = lfd.limitedParameterList()
                if lpl is not None:
                    for lts, ld in zip(lpl.limitedTypeSpecifier(), lpl.limitedDeclarator()):
                        used_struct = lts.Identifier()
                        if used_struct is not None:
                            u_name, u_idx = name_and_index(tokens, used_struct)
                            used_names_decl.append((u_name, u_idx, False, ld.getChildCount() == 1))
                used_names_defn = statement_uses(tokens, lfd.limitedCompoundStatement())

                name, idx = name_and_index(tokens, lfd.limitedDeclarator().Identifier())
                # Assume no overloading of names
                variables[name] = idx, used_names_decl, used_names_defn, is_exported, lfd
            elif lstruct is not None:
                used_names_decl = []
                used_names_defn = statement_uses(tokens, lstruct.limitedCompoundStatement())

                name, idx = name_and_index(tokens, lstruct.Identifier())
                types[name] = idx, used_names_decl, used_names_defn, is_exported, lstruct
            elif lg is not None:
                used_names_decl, used_names_defn = [], []
                struct_type = lg.limitedTypeSpecifier().Identifier()
                if struct_type is not None:
                    u_name, u_idx = name_and_index(tokens, struct_type)
                    # global decls actually don't need the struct defn
                    used_names_decl.append((u_name, u_idx, False, False))
                    used_names_defn.append((u_name, u_idx, False, lg.limitedDeclarator().getChildCount() == 1))
                li = lg.limitedInitializer()
                if li is not None:
                    for identifier in li.Identifier():
                        u_name, u_idx = name_and_index(tokens, identifier)
                        used_names_defn.append((u_name, u_idx, True, False))

                name, idx = name_and_index(tokens, lg.limitedDeclarator().Identifier())
                variables[name] = idx, used_names_decl, used_names_defn, is_exported, lg

    print("---- MODULE DATA: ----")
    pprint(module_data)

    # ---- Pass 3: combine imported symbols into symbol tables for each module ----
    # (output: module_input)
    module_input = {}  # dict[file_path: str, (type_symbol_table, var_symbol_table)]
    for file_path, imported_file_paths in module_imports.items():
        type_symbol_table = {}  # dict[name: str, file_path: str]
        var_symbol_table = {}  # dict[name: str, file_path: str]
        for imported_file_path in imported_file_paths:
            types, variables = module_data[imported_file_path]
            # Assume no clashes (ie. overwrites)
            for name, (_, _, _, is_exported, _) in types.items():
                if is_exported:
                    type_symbol_table[name] = imported_file_path
            for name, (_, _, _, is_exported, _) in variables.items():
                if is_exported:
                    var_symbol_table[name] = imported_file_path
        # The module's own symbols are added last, so they win over imported ones.
        types, variables = module_data[file_path]
        for name in types:
            type_symbol_table[name] = file_path
        for name in variables:
            var_symbol_table[name] = file_path
        module_input[file_path] = type_symbol_table, var_symbol_table

    print("---- MODULE INPUT: ----")
    pprint(module_input)

    # ---- Pass 4: order the needed symbols and rename them to the correct module ----
    # (output: module_output)
    def full_name(module_path, symbol):
        """Return the module-qualified name used in the generated output."""
        return module_path.replace('/', '$') + '$$' + symbol

    module_output = {}  # dict[file_path: str, generated_output: str]
    for file_path in module_asts:
        # dict[(file_path: str, name: str, is_variable: bool), (idx: int, is_defn: bool)]
        already_added = {}
        generated_parts = []  # list[symbol_output: str]
        # set[(file_path, name, is_variable, is_defn: bool)]; entries are never
        # removed, so this holds every request that has been started.
        circular_check = set()

        def is_already_added(file_path, name, is_variable: bool, trying_to_add_defn: bool):
            """Tell whether the symbol was already emitted in a sufficient form.

            A previously emitted declaration does not satisfy a request for the
            definition; anything already emitted satisfies a declaration request.
            """
            entry = already_added.get((file_path, name, is_variable))
            if entry is None:
                return False
            _, is_defn = entry
            # This helps avoid double definitions
            return is_defn or not trying_to_add_defn

        def try_add_to_parts(file_path, name, is_variable, trying_to_add_defn, idx, used_names_decl, used_names_defn, ast):
            """Emit one symbol (declaration or definition) after everything it needs.

            ``file_path`` is the module that defines the symbol, ``idx`` the
            token index of its name and ``ast`` its parse-tree node.
            """
            if is_already_added(file_path, name, is_variable, trying_to_add_defn):
                return
            request = (file_path, name, is_variable, trying_to_add_defn)
            if request in circular_check:
                circular_reference_str = f"// Circular reference! {file_path}, {name}, {is_variable}, {trying_to_add_defn}"
                print(circular_reference_str)
                generated_parts.append(circular_reference_str)
                return
            circular_check.add(request)

            def add_necessary_symbols(used_names_):
                """Emit every symbol referenced in used_names_ before this one."""
                for u_name, _, u_is_variable, needs_defn in used_names_:
                    symbol_table = module_input[file_path][u_is_variable]
                    if u_name not in symbol_table:
                        external_reference_str = f"// External reference! {file_path}, {u_name}, {u_is_variable}"
                        print(external_reference_str)
                        generated_parts.append(external_reference_str)
                        continue
                    m_file_path = symbol_table[u_name]
                    m_idx, m_used_names_decl, m_used_names_defn, _, m_ast = module_data[m_file_path][u_is_variable][u_name]
                    try_add_to_parts(m_file_path, u_name, u_is_variable, needs_defn, m_idx, m_used_names_decl, m_used_names_defn, m_ast)

            add_necessary_symbols(used_names_decl)
            if trying_to_add_defn:
                add_necessary_symbols(used_names_defn)

            text, tokens, _ = module_asts[file_path]
            token_start, token_end = ast.getSourceInterval()
            text_start = tokens[token_start].start
            if trying_to_add_defn:
                # A definition is the full source text of the node.
                symbol_output = text[text_start:tokens[token_end].stop + 1]
            elif isinstance(ast, (CMODParser.LimitedFunctionDefinitionContext, CMODParser.LimitedStructContext)):
                # Declaration: everything before the compound statement
                # (skipping whitespace tokens), terminated by ';'.
                token_end, _ = ast.limitedCompoundStatement().getSourceInterval()
                token_end -= 1
                while tokens[token_end].text.isspace():
                    token_end -= 1
                symbol_output = text[text_start:tokens[token_end].stop + 1] + ';'
            elif isinstance(ast, CMODParser.LimitedGlobalContext):
                # Declaration: 'extern ' + text up to the end of the declarator + ';'.
                _, token_end = ast.limitedDeclarator().getSourceInterval()
                extern_str = 'extern '
                symbol_output = extern_str + text[text_start:tokens[token_end].stop + 1] + ';'
                text_start -= len(extern_str)  # adjust text_start to make the rewriting work
            else:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError("This code path should not be reached")

            # Collect the rename of every identifier token first, keyed by token
            # index, then apply them from right to left so that earlier offsets
            # stay valid. Keying by token index also de-duplicates tokens that
            # appear in both used_names_decl and used_names_defn (the struct
            # type of a global), which would otherwise be rewritten twice at
            # stale offsets.
            replacements = {idx: full_name(file_path, name)}  # dict[token idx, new text]
            used_names = list(used_names_decl)
            if trying_to_add_defn:
                used_names.extend(used_names_defn)
            for u_name, u_idx, u_is_variable, _ in used_names:
                symbol_table = module_input[file_path][u_is_variable]
                if u_name in symbol_table:  # external references keep their name
                    replacements[u_idx] = full_name(symbol_table[u_name], u_name)
            for token_idx in sorted(replacements, reverse=True):
                replace_start = tokens[token_idx].start - text_start
                replace_end = tokens[token_idx].stop - text_start
                symbol_output = symbol_output[:replace_start] + replacements[token_idx] + symbol_output[replace_end + 1:]
            generated_parts.append(symbol_output)

            already_added[(file_path, name, is_variable)] = len(generated_parts) - 1, trying_to_add_defn

        types, variables = module_data[file_path]
        for name, (idx, used_names_decl, used_names_defn, _, ast) in types.items():
            try_add_to_parts(file_path, name, False, True, idx, used_names_decl, used_names_defn, ast)
        for name, (idx, used_names_decl, used_names_defn, _, ast) in variables.items():
            try_add_to_parts(file_path, name, True, True, idx, used_names_decl, used_names_defn, ast)
        module_output[file_path] = "\n\n".join(generated_parts)

    for file_path, generated_output in module_output.items():
        print(f"//////////////// START OF FILE {file_path} ////////////////")
        print(generated_output)


def simple_parse(input_file):
    """Debugging aid: parse ``input_file`` as a compilation unit and print the tree's text.

    Args:
        input_file: Path of the file to parse; it is passed to ``FileStream``
            unchanged (no ``.cmod`` suffix is appended here).
    """
    # Used in debugging
    input_stream = FileStream(input_file)
    lexer = CMODLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = CMODParser(stream)
    tree = parser.compilationUnit()
    print(tree.getText())


def details(ast):
    """Debugging aid: print a parse-tree node and a summary of its children.

    Chains of single-child nodes are descended automatically; the function
    then prints the child count of the node it ends on and, for each child,
    its index, its ``repr`` and its text.

    Args:
        ast: A node offering ``getChildCount()`` and ``getChild(i)``, whose
            children offer ``getText()``.
    """
    # Used in debugging
    print(f"Details for {ast!r}")
    num_auto_expand = 0
    while ast.getChildCount() == 1:
        ast = ast.getChild(0)
        num_auto_expand += 1
        # NOTE(review): the original indentation was lost; printing every
        # intermediate node (inside the loop) is assumed here - confirm.
        print(repr(ast))
    if num_auto_expand:
        print(f"auto expanded {num_auto_expand} times")
    n = ast.getChildCount()
    print(n)
    for i in range(n):
        child = ast.getChild(i)
        print()
        print(i)
        print(repr(child))
        print(child.getText())


# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()