diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index 551def69d..f4141c31f 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -1,10 +1,24 @@ import argparse import subprocess from abc import ABC, abstractmethod +from typing import Dict, List, Optional class Symbol: - def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: ''' A symbol is a function declaration. @@ -17,31 +31,31 @@ class Symbol: self._name = name self._op_name = op_name self._ret_type = ret_type - self._arg_names = arg_names - self._arg_types = arg_types + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) @property - def name(self): + def name(self) -> str: return self._name @property - def op_name(self): + def op_name(self) -> str: return self._op_name @property - def ret_type(self): + def ret_type(self) -> str: return self._ret_type @property - def arg_names(self): + def arg_names(self) -> List[str]: return self._arg_names @property - def arg_types(self): + def arg_types(self) -> List[str]: return self._arg_types -def convert_type(type_str): +def convert_type(type_str) -> Optional[str]: if type_str == "i32": return "int32" elif type_str == "u32": @@ -59,7 +73,7 @@ def convert_type(type_str): return None -def to_unsigned(type_str): +def to_unsigned(type_str) -> str: if type_str == "int32": return "uint32" elif type_str == "int64": @@ -69,7 +83,19 @@ def to_unsigned(type_str): class ExternLibrary(ABC): - def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: ''' Abstract class for extern library. @@ -80,34 +106,34 @@ class ExternLibrary(ABC): self._name = name self._path = path self._symbols = {} - self._format = True + self._format = format self._grouping = grouping @property - def name(self): + def name(self) -> str: return self._name @property - def path(self): + def path(self) -> str: return self._path @property - def symbols(self): + def symbols(self) -> Dict[str, Symbol]: return self._symbols @property - def grouping(self): + def grouping(self) -> bool: return self._grouping @abstractmethod - def parse_symbols(self, input_file): + def parse_symbols(self, input_file) -> None: pass @abstractmethod def _output_stubs(self) -> str: pass - def generate_stub_file(self, output_dir): + def generate_stub_file(self, output_dir) -> None: file_str = self._output_stubs() if file_str is None or len(file_str) == 0: raise Exception("file_str is empty") @@ -123,6 +149,8 @@ class ExternLibrary(ABC): class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + def __init__(self, path) -> None: ''' Constructor for Libdevice. @@ -132,7 +160,7 @@ class Libdevice(ExternLibrary): super().__init__("libdevice", path) self._symbol_groups = {} - def _extract_symbol(self, line): + def _extract_symbol(self, line) -> Optional[Symbol]: # Extract symbols from line in the following format: # "define [internal] @(,)" entries = line.split("@") @@ -174,7 +202,7 @@ class Libdevice(ExternLibrary): arg_types[i] = to_unsigned(arg_type) return Symbol(func_name, op_name, ret_type, arg_names, arg_types) - def _group_symbols(self): + def _group_symbols(self) -> None: symbol_set = {} for symbol in self._symbols.values(): op_name = symbol.op_name @@ -244,7 +272,7 @@ class Libdevice(ExternLibrary): else: self._symbol_groups[op_name] = [symbol] - def parse_symbols(self, input_file): + def parse_symbols(self, input_file) -> None: if len(self.symbols) > 0: return output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() @@ -256,7 +284,7 @@ class Libdevice(ExternLibrary): self._group_symbols() - def _output_stubs(self): + def _output_stubs(self) -> str: # Generate python functions in the following format: # @extern.extern # def (, _builder=None): @@ -297,7 +325,10 @@ class Libdevice(ExternLibrary): class LLVMDisassembler: - def __init__(self, path): + _path: str + _ll_file: str + + def __init__(self, path) -> None: ''' Invoke llvm-dis to disassemble the given file. @@ -306,23 +337,28 @@ class LLVMDisassembler: self._path = path self._ll_file = "/tmp/extern_lib.ll" - def disasm(self, lib_path): + def disasm(self, lib_path: str) -> None: subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() @property - def ll_file(self): + def ll_file(self) -> str: return self._ll_file @property - def path(self): + def path(self) -> str: return self._path extern_libs = ["libdevice"] -def build(llvm_dis_path, lib_path, lib_name, output_dir): +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: ''' Interface function to build the library file.