[FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace. - Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
This commit is contained in:
committed by
GitHub
parent
4e93b41c52
commit
fa62b4a8f6
@@ -74,8 +74,8 @@ public:
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
|
@@ -521,10 +521,10 @@ class LoadedBinary:
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
self.message = '\n'.join(src.split('\n')[:node.lineno])
|
||||
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||
self.message += '\n'.join(src.split('\n')[:node.lineno])
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
super().__init__(self.message)
|
||||
self.args = (src, node)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -1085,6 +1085,9 @@ class TensorWrapper:
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
return TensorWrapper(tensor, dtype)
|
||||
|
@@ -68,12 +68,20 @@ class dtype:
|
||||
def __init__(self, init):
|
||||
self.init = init
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# The init functions are named something like 'get_int8'. Strip the prefix.
|
||||
nom = self.init.__name__
|
||||
prefix = 'get_'
|
||||
assert nom.startswith(prefix)
|
||||
return nom[len(prefix):]
|
||||
|
||||
def handle(self, builder):
|
||||
ctx = builder.context
|
||||
return self.init(ctx)
|
||||
|
||||
def __str__(self):
|
||||
return f"dtype({self.init.__name__})"
|
||||
return self.name
|
||||
|
||||
|
||||
class pointer_dtype:
|
||||
@@ -131,6 +139,10 @@ class block:
|
||||
# Data-type wrapper
|
||||
self.dtype = block._init_dtype(self.handle.type.scalar)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# ex. "float32[3,4]"
|
||||
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
|
||||
|
||||
@builtin
|
||||
def __add__(self, other, _builder=None):
|
||||
return frontend.add(self, other, _builder)
|
||||
|
Reference in New Issue
Block a user