diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 33c74f245..7fb14877a 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -74,8 +74,8 @@ public: bool is_fp8_ty() const { return id_ == FP8TyID; } bool is_fp16_ty() const { return id_ == FP16TyID; } bool is_bf16_ty() const { return id_ == BF16TyID; } - bool is_fp32_ty() const { return id_ == FP32TyID; } - bool is_fp64_ty() const { return id_ == FP64TyID; } + bool is_fp32_ty() const { return id_ == FP32TyID; } + bool is_fp64_ty() const { return id_ == FP64TyID; } bool is_label_ty() const { return id_ == LabelTyID;} bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index d00a9d50c..688508265 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -521,10 +521,10 @@ class LoadedBinary: class CompilationError(Exception): def __init__(self, src, node): - self.message = '\n'.join(src.split('\n')[:node.lineno]) + self.message = f'at {node.lineno}:{node.col_offset}:\n' + self.message += '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' super().__init__(self.message) - self.args = (src, node) class OutOfResources(Exception): @@ -1085,6 +1085,9 @@ class TensorWrapper: def data_ptr(self): return self.base.data_ptr() + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + def reinterpret(tensor, dtype): return TensorWrapper(tensor, dtype) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 5db77efdc..1c28cdef7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -68,12 +68,20 @@ class dtype: def __init__(self, init): self.init = init + @property + def name(self) -> str: + # The init functions are named something like 'get_int8'. Strip the prefix. + nom = self.init.__name__ + prefix = 'get_' + assert nom.startswith(prefix) + return nom[len(prefix):] + def handle(self, builder): ctx = builder.context return self.init(ctx) def __str__(self): - return f"dtype({self.init.__name__})" + return self.name class pointer_dtype: @@ -131,6 +139,10 @@ class block: # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) + def __str__(self) -> str: + # ex. "float32[3,4]" + return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + @builtin def __add__(self, other, _builder=None): return frontend.add(self, other, _builder)