[FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace. - Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
This commit is contained in:
committed by
GitHub
parent
4e93b41c52
commit
fa62b4a8f6
@@ -521,10 +521,10 @@ class LoadedBinary:
|
|||||||
|
|
||||||
class CompilationError(Exception):
|
class CompilationError(Exception):
|
||||||
def __init__(self, src, node):
|
def __init__(self, src, node):
|
||||||
self.message = '\n'.join(src.split('\n')[:node.lineno])
|
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||||
|
self.message += '\n'.join(src.split('\n')[:node.lineno])
|
||||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
self.args = (src, node)
|
|
||||||
|
|
||||||
|
|
||||||
class OutOfResources(Exception):
|
class OutOfResources(Exception):
|
||||||
@@ -1085,6 +1085,9 @@ class TensorWrapper:
|
|||||||
def data_ptr(self):
|
def data_ptr(self):
|
||||||
return self.base.data_ptr()
|
return self.base.data_ptr()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||||
|
|
||||||
|
|
||||||
def reinterpret(tensor, dtype):
|
def reinterpret(tensor, dtype):
|
||||||
return TensorWrapper(tensor, dtype)
|
return TensorWrapper(tensor, dtype)
|
||||||
|
@@ -68,12 +68,20 @@ class dtype:
|
|||||||
def __init__(self, init):
|
def __init__(self, init):
|
||||||
self.init = init
|
self.init = init
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
# The init functions are named something like 'get_int8'. Strip the prefix.
|
||||||
|
nom = self.init.__name__
|
||||||
|
prefix = 'get_'
|
||||||
|
assert nom.startswith(prefix)
|
||||||
|
return nom[len(prefix):]
|
||||||
|
|
||||||
def handle(self, builder):
|
def handle(self, builder):
|
||||||
ctx = builder.context
|
ctx = builder.context
|
||||||
return self.init(ctx)
|
return self.init(ctx)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"dtype({self.init.__name__})"
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class pointer_dtype:
|
class pointer_dtype:
|
||||||
@@ -131,6 +139,10 @@ class block:
|
|||||||
# Data-type wrapper
|
# Data-type wrapper
|
||||||
self.dtype = block._init_dtype(self.handle.type.scalar)
|
self.dtype = block._init_dtype(self.handle.type.scalar)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# ex. "float32[3,4]"
|
||||||
|
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __add__(self, other, _builder=None):
|
def __add__(self, other, _builder=None):
|
||||||
return frontend.add(self, other, _builder)
|
return frontend.add(self, other, _builder)
|
||||||
|
Reference in New Issue
Block a user