source: doc/proposals/modules-alvin/1_stitched_modules/Driver.py

stuck-waitfor-destruct
Last change on this file was 2cb10170, checked in by Alvin Zhang <alvin.zhang@…>, 3 weeks ago

stitched modules proposal

  • Property mode set to 100644
File size: 17.2 KB
Line 
1import argparse
2import os
3from pprint import pprint
4from antlr4 import *
5from parser.CMODLexer import CMODLexer
6from parser.CMODParser import CMODParser
7
def parse_args():
    """Build and run the CLI parser.

    Returns an argparse.Namespace with:
      input_file   -- positional; path of the .cmod file to parse
      project_root -- optional (-r); root folder used to derive module names
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input_file", help="file to parse")
    cli.add_argument("-r", "--project_root", help="the root folder of the project, used to generate module names. All modules, including imports, should be within this folder")
    return cli.parse_args()
14
def main():
    """Entry point: parse the command line, then run the stitching pipeline."""
    parse_top_level(parse_args())
18
def parse_top_level(args):
    """Stitch a .cmod module and all of its transitive imports into C output.

    Pipeline (each stage's output feeds the next):
      1. Parse args.input_file and every recursively imported module
         -> module_asts, module_imports
      2. Extract top-level symbols (functions, structs, globals) per module
         -> module_data
      3. Merge exported symbols of imports into per-module symbol tables
         -> module_input
      4. Topologically order needed symbols and rename them to fully
         qualified module names -> module_output (printed to stdout)

    Raises RuntimeError on a bad suffix, an out-of-root import, or syntax
    errors in any parsed module.
    """
    # args.input_file
    # args.project_root

    # This changes the OS directory to be project_root, and removes the suffix from input_file
    # (it makes it easier to navigate the imports)
    input_file = os.path.relpath(args.input_file, args.project_root)
    os.chdir(args.project_root)
    dot_index = input_file.rfind('.')
    if dot_index < 0 or input_file[dot_index:] != '.cmod':
        raise RuntimeError("input_file must use .cmod suffix")
    input_file = input_file[:dot_index]

    # Type notes to help you understand this code
    # (it helps to have this open in another tab while reading the code)
    #
    # module_asts: dict[file_path: str, (text, tokens, tree)]
    # module_imports: dict[file_path: str, list[imported_file_path: str]]
    # text # str
    # tokens[0].text # start, stop, line, column
    # a, b = tree.getSourceInterval() # getChildren, getChildCount, getChild
    #
    # module_data: dict[file_path: str, (types, variables)]
    # types, variables: dict[name: str,
    #     (idx: int, used_names_decl, used_names_defn, is_exported: bool, AST)]
    # used_names_decl, used_names_defn:
    #     list[(name: str, idx: int, is_variable: bool, needs_defn: bool)]
    #
    # module_input: dict[file_path: str, (type_symbol_table, var_symbol_table)]
    # type_symbol_table, var_symbol_table: dict[name: str, file_path: str]

    # This parses the input file, as well as any recursively imported modules
    # (output: module_asts, module_imports)
    module_asts = {}  # dict[file_path: str, (text, tokens, tree)]
    module_imports = {}  # dict[file_path: str, list[imported_file_path: str]]
    modules_to_process = [input_file]  # list[file_path: str] -- worklist (DFS order)
    while modules_to_process:
        file_path = modules_to_process.pop()
        if file_path in module_asts:
            # Already parsed via another import chain; skip.
            continue
        if "." in file_path:
            # A '.' survives relpath/normpath only when the import escapes
            # project_root (e.g. "../x") or carries a file suffix.
            raise RuntimeError(f"{file_path=} contains '.', which is disallowed. Make sure {args.project_root=} contains all modules")
        input_stream = FileStream(file_path + ".cmod")
        text = str(input_stream)
        lexer = CMODLexer(input_stream)
        tokens = lexer.getAllTokens()  # list[token: {text, start, stop, line, column}]
        # stop is inclusive, just like getSourceInterval
        lexer.reset()
        stream = CommonTokenStream(lexer)
        parser = CMODParser(stream)
        tree = parser.compilationUnit()  # {getSourceInterval, getChildren, getChildCount, getChild}
        if parser.getNumberOfSyntaxErrors() > 0:
            raise RuntimeError(f"syntax errors for file {file_path}")
        module_asts[file_path] = text, tokens, tree

        # Recover each import's module name from the raw source text between
        # the first and second-to-last child tokens of the declaration
        # (presumably skipping the 'import' keyword and trailing ';' -- verify
        # against the CMOD grammar).
        import_names = []  # list[import_name: str]
        for importDeclaration in tree.getChild(0).getChildren():
            if type(importDeclaration) != CMODParser.ImportDeclarationContext:
                continue
            numChildren = importDeclaration.getChildCount()
            token_start, _ = importDeclaration.getChild(1).getSourceInterval()
            text_start = tokens[token_start].start
            _, token_end = importDeclaration.getChild(numChildren-2).getSourceInterval()
            input_end = tokens[token_end].stop
            import_name = text[text_start:input_end+1]
            if import_name.startswith('"'):
                # Strip surrounding quotes from string-literal import names.
                import_name = import_name[1:len(import_name)-1]
            import_names.append(import_name)
        # Imports are resolved relative to the importing file's directory.
        file_path_dir = os.path.split(file_path)[0]
        import_file_paths = [os.path.normpath('./'+file_path_dir+'/'+x) for x in import_names]
        module_imports[file_path] = import_file_paths
        modules_to_process.extend(module_imports[file_path])

    print("---- MODULE IMPORTS: ----")
    pprint(module_imports)

    # This analyzes each module, extracting top-level symbols
    # (output: module_data)
    module_data = {}  # dict[file_path: str, (types, variables)]
    for file_path, (_, tokens, tree) in module_asts.items():
        types, variables = {}, {}  # dict[name: str, (idx: int, used_names_decl, used_names_defn, is_exported: bool, AST)]
        if tree.translationUnit() is None:
            # Module with no top-level declarations (imports only).
            module_data[file_path] = types, variables
            continue
        for externalDeclaration in tree.translationUnit().externalDeclaration():
            # Exactly one of these three is non-None per declaration.
            lfd = externalDeclaration.limitedFunctionDefinition()
            lstruct = externalDeclaration.limitedStruct()
            lg = externalDeclaration.limitedGlobal()
            # Exported iff the declaration's first token is 'export'.
            is_exported = tokens[externalDeclaration.getChild(0).getSourceInterval()[0]].text == 'export'
            if lfd is not None:
                # Function: struct names in the signature go to used_names_decl,
                # identifiers in the body go to used_names_defn.
                used_names_decl, used_names_defn = [], []  # list[(name: str, idx: int, is_variable: bool, needs_defn: bool)]
                return_struct = lfd.limitedTypeSpecifier().Identifier()
                if return_struct is not None:
                    idx, _ = return_struct.getSourceInterval()
                    name = tokens[idx].text
                    # needs_defn when the declarator has a single child --
                    # presumably a by-value (non-pointer) use; confirm against grammar.
                    used_names_decl.append((name, idx, False, lfd.limitedDeclarator().getChildCount() == 1))
                lpl = lfd.limitedParameterList()
                if lpl is not None:
                    lts_array = lpl.limitedTypeSpecifier()
                    ld_array = lpl.limitedDeclarator()
                    for lts, ld in zip(lts_array, ld_array):
                        used_struct = lts.Identifier()
                        if used_struct is not None:
                            idx, _ = used_struct.getSourceInterval()
                            name = tokens[idx].text
                            used_names_decl.append((name, idx, False, ld.getChildCount() == 1))
                for ls in lfd.limitedCompoundStatement().limitedStatement():
                    idx, _ = ls.Identifier().getSourceInterval()
                    name = tokens[idx].text
                    # Child-count encodes the statement form: 2 -> variable use,
                    # 3 -> use that needs the definition (per the encodings below).
                    used_names_defn.append((name, idx, ls.getChildCount() == 2, ls.getChildCount() == 3))

                idx, _ = lfd.limitedDeclarator().Identifier().getSourceInterval()
                name = tokens[idx].text

                # Assume no overloading of names
                variables[name] = idx, used_names_decl, used_names_defn, is_exported, lfd
            elif lstruct is not None:
                # Struct: member statements contribute to used_names_defn only.
                used_names_decl, used_names_defn = [], []
                for ls in lstruct.limitedCompoundStatement().limitedStatement():
                    idx, _ = ls.Identifier().getSourceInterval()
                    name = tokens[idx].text
                    used_names_defn.append((name, idx, ls.getChildCount() == 2, ls.getChildCount() == 3))

                idx, _ = lstruct.Identifier().getSourceInterval()
                name = tokens[idx].text

                types[name] = idx, used_names_decl, used_names_defn, is_exported, lstruct
            elif lg is not None:
                # Global variable: its struct type and initializer identifiers.
                used_names_decl, used_names_defn = [], []
                struct_type = lg.limitedTypeSpecifier().Identifier()
                if struct_type is not None:
                    idx, _ = struct_type.getSourceInterval()
                    name = tokens[idx].text
                    # global decls actually don't need the struct defn
                    used_names_decl.append((name, idx, False, False))
                    used_names_defn.append((name, idx, False, lg.limitedDeclarator().getChildCount() == 1))
                li = lg.limitedInitializer()
                if li is not None:
                    for identifier in li.Identifier():
                        idx, _ = identifier.getSourceInterval()
                        name = tokens[idx].text
                        used_names_defn.append((name, idx, True, False))

                idx, _ = lg.limitedDeclarator().Identifier().getSourceInterval()
                name = tokens[idx].text

                variables[name] = idx, used_names_decl, used_names_defn, is_exported, lg
        module_data[file_path] = types, variables

    print("---- MODULE DATA: ----")
    pprint(module_data)

    # This combines imported symbols to produce symbol tables for each module
    # (output: module_input)
    # Note: module_input[fp] is a 2-tuple, later indexed by the bool
    # is_variable (False -> type table, True -> variable table).
    module_input = {}  # dict[file_path: str, (type_symbol_table, var_symbol_table)]
    for file_path, imported_file_paths in module_imports.items():
        type_symbol_table = {}  # dict[name: str, file_path: str]
        var_symbol_table = {}  # dict[name: str, file_path: str]
        for imported_file_path in imported_file_paths:
            types, variables = module_data[imported_file_path]
            # Assume no clashes (ie. overwrites)
            for name, (_, _, _, is_exported, _) in types.items():
                if is_exported:
                    type_symbol_table[name] = imported_file_path
            for name, (_, _, _, is_exported, _) in variables.items():
                if is_exported:
                    var_symbol_table[name] = imported_file_path
        # A module's own symbols shadow imported ones (added last).
        types, variables = module_data[file_path]
        for name in types:
            type_symbol_table[name] = file_path
        for name in variables:
            var_symbol_table[name] = file_path
        module_input[file_path] = type_symbol_table, var_symbol_table

    print("---- MODULE INPUT: ----")
    pprint(module_input)

    # This performs a topological ordering of any needed symbols, as well as renaming them to the correct module
    # (output: module_output)
    module_output = {} # dict[file_path: str, generated_output: str]
    for file_path in module_asts:
        # Maps emitted symbols to (position in generated_parts, whether the
        # emitted text was a full definition or just a declaration).
        already_added = {} # dict[(file_path: str, name: str, is_variable: bool), (idx: int, is_defn: bool)]
        def is_already_added(file_path, name, is_variable: bool, trying_to_add_defn: bool):
            if (file_path, name, is_variable) in already_added:
                # This helps avoid double definitions
                if trying_to_add_defn:
                    _, is_defn = already_added[(file_path, name, is_variable)]
                    return is_defn
                else:
                    return True
            else:
                return False

        generated_parts = [] # list[symbol_output: str]
        # Guards against infinite recursion; cycles are emitted as comments.
        circular_check = set() # set[file_path, name, is_variable, is_defn: bool]
        def try_add_to_parts(file_path, name, is_variable, trying_to_add_defn, idx, used_names_decl, used_names_defn, ast):
            """Recursively emit dependencies, then this symbol, into generated_parts."""
            if is_already_added(file_path, name, is_variable, trying_to_add_defn):
                return
            if (file_path, name, is_variable, trying_to_add_defn) in circular_check:
                circular_reference_str = f"// Circular reference! {file_path}, {name}, {is_variable}, {trying_to_add_defn}"
                print(circular_reference_str)
                generated_parts.append(circular_reference_str)
                return
            circular_check.add((file_path, name, is_variable, trying_to_add_defn))

            # Emit every symbol this one depends on, first. A dependency that
            # resolves to nothing in the symbol table is assumed external.
            def add_necessary_symbols(used_names_):
                for u_name, _, u_is_variable, needs_defn in used_names_:
                    if u_name not in module_input[file_path][u_is_variable]:
                        external_reference_str = f"// External reference! {file_path}, {u_name}, {u_is_variable}"
                        print(external_reference_str)
                        generated_parts.append(external_reference_str)
                        continue
                    m_file_path = module_input[file_path][u_is_variable][u_name]
                    m_idx, m_used_names_decl, m_used_names_defn, _, m_ast = module_data[m_file_path][u_is_variable][u_name]
                    try_add_to_parts(m_file_path, u_name, u_is_variable, needs_defn, m_idx, m_used_names_decl, m_used_names_defn, m_ast)
            add_necessary_symbols(used_names_decl)
            if trying_to_add_defn:
                add_necessary_symbols(used_names_defn)

            # Slice this symbol's source text out of its module.
            text, tokens, _ = module_asts[file_path]
            if trying_to_add_defn:
                # Full definition: copy the whole source interval verbatim.
                token_start, token_end = ast.getSourceInterval()
                text_start = tokens[token_start].start
                input_end = tokens[token_end].stop
                symbol_output = text[text_start:input_end+1]
            else:
                # Declaration only: truncate before the body / initializer
                # and append ';' to form a forward declaration.
                token_start, _ = ast.getSourceInterval()
                text_start = tokens[token_start].start
                if type(ast) == CMODParser.LimitedFunctionDefinitionContext:
                    # Cut just before the compound statement, trimming
                    # trailing whitespace tokens.
                    token_end, _ = ast.limitedCompoundStatement().getSourceInterval()
                    token_end -= 1
                    while tokens[token_end].text.isspace():
                        token_end -= 1
                    input_end = tokens[token_end].stop
                    symbol_output = text[text_start:input_end+1] + ';'
                elif type(ast) == CMODParser.LimitedGlobalContext:
                    # Globals are declared 'extern' up to the declarator.
                    _, token_end = ast.limitedDeclarator().getSourceInterval()
                    input_end = tokens[token_end].stop
                    extern_str = 'extern '
                    symbol_output = extern_str + text[text_start:input_end+1] + ';'
                    text_start -= len(extern_str) # adjust text_start to make the rewriting work
                elif type(ast) == CMODParser.LimitedStructContext:
                    token_end, _ = ast.limitedCompoundStatement().getSourceInterval()
                    token_end -= 1
                    while tokens[token_end].text.isspace():
                        token_end -= 1
                    input_end = tokens[token_end].stop
                    symbol_output = text[text_start:input_end+1] + ';'
                else:
                    assert False, "This code path should not be reached"

            # This renaming is very unoptimized and relies on used_names_* being in order
            def full_name(file_path, name):
                # e.g. "a/b" + "f" -> "a$b$$f": a module-qualified C identifier.
                return file_path.replace('/', '$') + '$$' + name
            def update_symbol_output(file_path, name, idx):
                # Splice the qualified name over token idx's span within
                # symbol_output (offsets are relative to text_start).
                nonlocal symbol_output
                replace_start = tokens[idx].start - text_start
                replace_end = tokens[idx].stop - text_start
                symbol_output = symbol_output[:replace_start] + full_name(file_path, name) + symbol_output[replace_end+1:]
            # Replacements must proceed from the highest token offset downward
            # so earlier splices don't invalidate later offsets: body names,
            # then decl names after idx, then idx itself, then decl names
            # before idx.
            # NOTE(review): for globals, used_names_defn includes the struct
            # type which precedes the declarator token -- verify the
            # decreasing-offset assumption still holds in that case.
            if trying_to_add_defn:
                for u_name, u_idx, u_is_variable, _ in reversed(used_names_defn):
                    if u_name not in module_input[file_path][u_is_variable]:
                        continue
                    m_file_path = module_input[file_path][u_is_variable][u_name]
                    update_symbol_output(m_file_path, u_name, u_idx)
            for u_name, u_idx, u_is_variable, _ in reversed(used_names_decl):
                if u_idx < idx:
                    break
                if u_name not in module_input[file_path][u_is_variable]:
                    continue
                m_file_path = module_input[file_path][u_is_variable][u_name]
                update_symbol_output(m_file_path, u_name, u_idx)
            update_symbol_output(file_path, name, idx)
            for u_name, u_idx, u_is_variable, _ in reversed(used_names_decl):
                if u_idx >= idx:
                    continue
                if u_name not in module_input[file_path][u_is_variable]:
                    continue
                m_file_path = module_input[file_path][u_is_variable][u_name]
                update_symbol_output(m_file_path, u_name, u_idx)
            generated_parts.append(symbol_output)

            already_added[(file_path, name, is_variable)] = len(generated_parts)-1, trying_to_add_defn

        # Seed the traversal with every symbol of this module (full defns).
        types, variables = module_data[file_path]
        for name, (idx, used_names_decl, used_names_defn, _, ast) in types.items():
            try_add_to_parts(file_path, name, False, True, idx, used_names_decl, used_names_defn, ast)
        for name, (idx, used_names_decl, used_names_defn, _, ast) in variables.items():
            try_add_to_parts(file_path, name, True, True, idx, used_names_decl, used_names_defn, ast)
        module_output[file_path] = "\n\n".join(generated_parts)

    for file_path, generated_output in module_output.items():
        print(f"//////////////// START OF FILE {file_path} ////////////////")
        print(generated_output)
313
def simple_parse(input_file):
    """Debug helper: parse a single file and dump the flat AST text to stdout."""
    char_stream = FileStream(input_file)
    token_stream = CommonTokenStream(CMODLexer(char_stream))
    root = CMODParser(token_stream).compilationUnit()
    print(root.getText())
322
def details(ast):
    # Used in debugging: print an AST node, auto-expanding single-child
    # chain nodes, then list each direct child with its repr and text.
    print(f"Details for {repr(ast)}")
    num_auto_expand = 0
    # Descend through trivial one-child wrapper nodes.
    # NOTE(review): source indentation was mangled; print(repr(ast)) is
    # reconstructed as inside this loop (one line per expansion step) --
    # confirm against the original file.
    while ast.getChildCount() == 1:
        ast = ast.getChild(0)
        num_auto_expand += 1
        print(repr(ast))
    if num_auto_expand:
        print(f"auto expanded {num_auto_expand} times")
    n = ast.getChildCount()
    print(n)
    # One blank-line-separated record per child: index, repr, source text.
    for i in range(n):
        print()
        print(i)
        print(repr(ast.getChild(i)))
        print(ast.getChild(i).getText())
340
# Script entry point: run the module-stitching driver on the CLI arguments.
if __name__ == '__main__':
    main()
Note: See TracBrowser for help on using the repository browser.