add missing bfloat16 repr and improve assertions (#403)

- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
This commit is contained in:
Madeleine Thompson
2021-12-23 17:01:17 -08:00
committed by GitHub
parent d8fce83e7a
commit 985798f101
7 changed files with 12 additions and 12 deletions

View File

@@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_not_implemented("cast");
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
/*----------------------------------------------

View File

@@ -187,6 +187,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
params_ptr += 8;
py::object dtype = arg.attr("dtype");
py::object repr = py::repr(dtype);
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len);