[FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)

Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
This commit is contained in:
Crutcher Dunnavant
2022-11-17 09:45:30 -08:00
committed by GitHub
parent 0d7e753227
commit 0e4691e6dd

View File

@@ -1,10 +1,24 @@
import argparse
import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
class Symbol:
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
_name: str
_op_name: str
_ret_type: str
_arg_names: List[str]
_arg_types: List[str]
def __init__(
self,
name: str,
op_name: str,
ret_type: str,
arg_names: List[str],
arg_types: List[str],
) -> None:
'''
A symbol is a function declaration.
@@ -17,31 +31,31 @@ class Symbol:
self._name = name
self._op_name = op_name
self._ret_type = ret_type
self._arg_names = arg_names
self._arg_types = arg_types
self._arg_names = list(arg_names)
self._arg_types = list(arg_types)
@property
def name(self):
def name(self) -> str:
return self._name
@property
def op_name(self):
def op_name(self) -> str:
return self._op_name
@property
def ret_type(self):
def ret_type(self) -> str:
return self._ret_type
@property
def arg_names(self):
def arg_names(self) -> List[str]:
return self._arg_names
@property
def arg_types(self):
def arg_types(self) -> List[str]:
return self._arg_types
def convert_type(type_str):
def convert_type(type_str) -> Optional[str]:
if type_str == "i32":
return "int32"
elif type_str == "u32":
@@ -59,7 +73,7 @@ def convert_type(type_str):
return None
def to_unsigned(type_str):
def to_unsigned(type_str) -> str:
if type_str == "int32":
return "uint32"
elif type_str == "int64":
@@ -69,7 +83,19 @@ def to_unsigned(type_str):
class ExternLibrary(ABC):
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
_name: str
_path: str
_symbols: Dict[str, Symbol]
_format: bool
_grouping: bool
def __init__(
self,
name: str,
path: str,
format: bool = True,
grouping: bool = True,
) -> None:
'''
Abstract class for extern library.
@@ -80,34 +106,34 @@ class ExternLibrary(ABC):
self._name = name
self._path = path
self._symbols = {}
self._format = True
self._format = format
self._grouping = grouping
@property
def name(self):
def name(self) -> str:
return self._name
@property
def path(self):
def path(self) -> str:
return self._path
@property
def symbols(self):
def symbols(self) -> Dict[str, Symbol]:
return self._symbols
@property
def grouping(self):
def grouping(self) -> bool:
return self._grouping
@abstractmethod
def parse_symbols(self, input_file):
def parse_symbols(self, input_file) -> None:
pass
@abstractmethod
def _output_stubs(self) -> str:
pass
def generate_stub_file(self, output_dir):
def generate_stub_file(self, output_dir) -> None:
file_str = self._output_stubs()
if file_str is None or len(file_str) == 0:
raise Exception("file_str is empty")
@@ -123,6 +149,8 @@ class ExternLibrary(ABC):
class Libdevice(ExternLibrary):
_symbol_groups: Dict[str, List[Symbol]]
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
@@ -132,7 +160,7 @@ class Libdevice(ExternLibrary):
super().__init__("libdevice", path)
self._symbol_groups = {}
def _extract_symbol(self, line):
def _extract_symbol(self, line) -> Optional[Symbol]:
# Extract symbols from line in the following format:
# "define [internal] <ret_type> @<name>(<arg_types>,)"
entries = line.split("@")
@@ -174,7 +202,7 @@ class Libdevice(ExternLibrary):
arg_types[i] = to_unsigned(arg_type)
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
def _group_symbols(self):
def _group_symbols(self) -> None:
symbol_set = {}
for symbol in self._symbols.values():
op_name = symbol.op_name
@@ -244,7 +272,7 @@ class Libdevice(ExternLibrary):
else:
self._symbol_groups[op_name] = [symbol]
def parse_symbols(self, input_file):
def parse_symbols(self, input_file) -> None:
if len(self.symbols) > 0:
return
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
@@ -256,7 +284,7 @@ class Libdevice(ExternLibrary):
self._group_symbols()
def _output_stubs(self):
def _output_stubs(self) -> str:
# Generate python functions in the following format:
# @extern.extern
# def <op_name>(<args>, _builder=None):
@@ -297,7 +325,10 @@ class Libdevice(ExternLibrary):
class LLVMDisassembler:
def __init__(self, path):
_path: str
_ll_file: str
def __init__(self, path) -> None:
'''
Invoke llvm-dis to disassemble the given file.
@@ -306,23 +337,28 @@ class LLVMDisassembler:
self._path = path
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path):
def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
@property
def ll_file(self):
def ll_file(self) -> str:
return self._ll_file
@property
def path(self):
def path(self) -> str:
return self._path
extern_libs = ["libdevice"]
def build(llvm_dis_path, lib_path, lib_name, output_dir):
def build(
llvm_dis_path: str,
lib_path: str,
lib_name: str,
output_dir: str,
) -> None:
'''
Interface function to build the library file.