From 985798f10127fe5c1e918c09e1889de51d65abfe Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Thu, 23 Dec 2021 17:01:17 -0800 Subject: [PATCH] add missing bfloat16 repr and improve assertions (#403) - `BF16TyID` was missing a repr implementation. - Throw a better exception on impossible casts. - Add a few assertions. Tested with a debug build. - Add `pointer_dtype.__str__` to aid kernel debugging. --- include/triton/ir/type.h | 4 ++-- lib/ir/dispatch.cc | 4 ++-- lib/ir/instructions.cc | 3 +-- python/src/functions.h | 2 +- python/src/triton.cc | 1 + python/test/unit/language/test_core.py | 5 ----- python/triton/language/core.py | 5 +++++ 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 7fb14877a..c9c07c4f1 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -131,6 +131,7 @@ public: case FP16TyID: return "f16"; case FP32TyID: return "f32"; case FP64TyID: return "f64"; + case BF16TyID: return "bf16"; case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; @@ -141,8 +142,7 @@ public: case BlockTyID: return tile_repr(); default: break; } - assert(false); - return ""; + throw std::logic_error("unknown type id '" + std::to_string(id_) + "'"); }; private: diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index c616b2fd4..c4e8ccafb 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -55,7 +55,7 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) } } if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) - throw_unreachable("augment_types"); + throw_unreachable("computation_type"); // 4 ) both operands are integer and undergo // integer promotion return integer_promote(a_ty, b_ty); @@ -493,7 +493,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - return throw_unreachable("cast"); + return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 32e7674c6..00d801616 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -232,6 +232,7 @@ icmp_inst::icmp_inst(type *ty, cmp_pred_t pred, icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_int_predicate(pred)); + assert(lhs->get_type() == rhs->get_type()); type *res_ty = make_cmp_result_type(lhs->get_type()); return new icmp_inst(res_ty, pred, lhs, rhs, name, next); } @@ -920,7 +921,5 @@ const constant_int* make_range::get_last() const { return last_; } - - } } diff --git a/python/src/functions.h b/python/src/functions.h index 0f5a5c42f..19f7e7eb9 100644 --- a/python/src/functions.h +++ b/python/src/functions.h @@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) { other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - throw_not_implemented("cast"); + throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } /*---------------------------------------------- diff --git a/python/src/triton.cc b/python/src/triton.cc index b44ffbc27..cec6fba94 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -187,6 +187,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params_ptr += 8; py::object dtype = arg.attr("dtype"); py::object repr = py::repr(dtype); + assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6)); const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; cache_key += std::string(start, len); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 785ca49ac..aa0e7430a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -453,11 +453,6 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): # --------------- # test permute # --------------- - -# --------------- -# test permute -# --------------- - @pytest.mark.parametrize("dtype, shape, perm", [(dtype, shape, perm) \ for dtype in ['float32']\ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 1c28cdef7..e939319aa 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -86,11 +86,16 @@ class dtype: class pointer_dtype: def __init__(self, element_ty): + if not isinstance(element_ty, dtype): + raise TypeError('element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty def handle(self, builder): return ir.type.make_ptr(self.element_ty.handle(builder), 1) + def __str__(self): + return f'pointer<{self.element_ty}>' + # scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8)