[FRONTEND] better stringification (#394)

- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace.
- Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
This commit is contained in:
Madeleine Thompson
2021-12-17 20:11:45 -08:00
committed by GitHub
parent 4e93b41c52
commit fa62b4a8f6
3 changed files with 20 additions and 5 deletions

View File

@@ -74,8 +74,8 @@ public:
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_fp16_ty() const { return id_ == FP16TyID; }
bool is_bf16_ty() const { return id_ == BF16TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }

View File

@@ -521,10 +521,10 @@ class LoadedBinary:
class CompilationError(Exception):
def __init__(self, src, node):
self.message = '\n'.join(src.split('\n')[:node.lineno])
self.message = f'at {node.lineno}:{node.col_offset}:\n'
self.message += '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
super().__init__(self.message)
self.args = (src, node)
class OutOfResources(Exception):
@@ -1085,6 +1085,9 @@ class TensorWrapper:
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)

View File

@@ -68,12 +68,20 @@ class dtype:
def __init__(self, init):
self.init = init
@property
def name(self) -> str:
# The init functions are named something like 'get_int8'. Strip the prefix.
nom = self.init.__name__
prefix = 'get_'
assert nom.startswith(prefix)
return nom[len(prefix):]
def handle(self, builder):
ctx = builder.context
return self.init(ctx)
def __str__(self):
return f"dtype({self.init.__name__})"
return self.name
class pointer_dtype:
@@ -131,6 +139,10 @@ class block:
# Data-type wrapper
self.dtype = block._init_dtype(self.handle.type.scalar)
def __str__(self) -> str:
# ex. "float32[3,4]"
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
@builtin
def __add__(self, other, _builder=None):
return frontend.add(self, other, _builder)