add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation. - Throw a better exception on impossible casts. - Add a few assertions. Tested with a debug build. - Add `pointer_dtype.__str__` to aid kernel debugging.
This commit is contained in:
committed by
GitHub
parent
d8fce83e7a
commit
985798f101
@@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
|
||||
other = builder->create_splat(other, src_ty->get_block_shapes());
|
||||
return builder->create_icmpNE(input, other);
|
||||
}
|
||||
throw_not_implemented("cast");
|
||||
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
|
@@ -187,6 +187,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
params_ptr += 8;
|
||||
py::object dtype = arg.attr("dtype");
|
||||
py::object repr = py::repr(dtype);
|
||||
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
||||
cache_key += std::string(start, len);
|
||||
|
@@ -453,11 +453,6 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
|
||||
# ---------------
|
||||
# test permute
|
||||
# ---------------
|
||||
|
||||
# ---------------
|
||||
# test permute
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("dtype, shape, perm",
|
||||
[(dtype, shape, perm) \
|
||||
for dtype in ['float32']\
|
||||
|
@@ -86,11 +86,16 @@ class dtype:
|
||||
|
||||
class pointer_dtype:
|
||||
def __init__(self, element_ty):
|
||||
if not isinstance(element_ty, dtype):
|
||||
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||
self.element_ty = element_ty
|
||||
|
||||
def handle(self, builder):
|
||||
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
||||
|
||||
def __str__(self):
|
||||
return f'pointer<{self.element_ty}>'
|
||||
|
||||
# scalar types
|
||||
int1 = dtype(ir.type.get_int1)
|
||||
int8 = dtype(ir.type.get_int8)
|
||||
|
Reference in New Issue
Block a user