add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
This commit is contained in:
committed by
GitHub
parent
d8fce83e7a
commit
985798f101
@@ -131,6 +131,7 @@ public:
|
|||||||
case FP16TyID: return "f16";
|
case FP16TyID: return "f16";
|
||||||
case FP32TyID: return "f32";
|
case FP32TyID: return "f32";
|
||||||
case FP64TyID: return "f64";
|
case FP64TyID: return "f64";
|
||||||
|
case BF16TyID: return "bf16";
|
||||||
case LabelTyID: return "label";
|
case LabelTyID: return "label";
|
||||||
case MetadataTyID: return "md";
|
case MetadataTyID: return "md";
|
||||||
case TokenTyID: return "tok";
|
case TokenTyID: return "tok";
|
||||||
@@ -141,8 +142,7 @@ public:
|
|||||||
case BlockTyID: return tile_repr();
|
case BlockTyID: return tile_repr();
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
assert(false);
|
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
|
||||||
return "";
|
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@@ -55,7 +55,7 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||||
throw_unreachable("augment_types");
|
throw_unreachable("computation_type");
|
||||||
// 4 ) both operands are integer and undergo
|
// 4 ) both operands are integer and undergo
|
||||||
// integer promotion
|
// integer promotion
|
||||||
return integer_promote(a_ty, b_ty);
|
return integer_promote(a_ty, b_ty);
|
||||||
@@ -493,7 +493,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
|||||||
other = builder->create_splat(other, src_ty->get_block_shapes());
|
other = builder->create_splat(other, src_ty->get_block_shapes());
|
||||||
return builder->create_icmpNE(input, other);
|
return builder->create_icmpNE(input, other);
|
||||||
}
|
}
|
||||||
return throw_unreachable("cast");
|
return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -232,6 +232,7 @@ icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
|
|||||||
|
|
||||||
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||||
assert(is_int_predicate(pred));
|
assert(is_int_predicate(pred));
|
||||||
|
assert(lhs->get_type() == rhs->get_type());
|
||||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||||
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
|
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||||
}
|
}
|
||||||
@@ -920,7 +921,5 @@ const constant_int* make_range::get_last() const {
|
|||||||
return last_;
|
return last_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
|
|||||||
other = builder->create_splat(other, src_ty->get_block_shapes());
|
other = builder->create_splat(other, src_ty->get_block_shapes());
|
||||||
return builder->create_icmpNE(input, other);
|
return builder->create_icmpNE(input, other);
|
||||||
}
|
}
|
||||||
throw_not_implemented("cast");
|
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*----------------------------------------------
|
/*----------------------------------------------
|
||||||
|
@@ -187,6 +187,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
params_ptr += 8;
|
params_ptr += 8;
|
||||||
py::object dtype = arg.attr("dtype");
|
py::object dtype = arg.attr("dtype");
|
||||||
py::object repr = py::repr(dtype);
|
py::object repr = py::repr(dtype);
|
||||||
|
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
|
||||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
||||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
||||||
cache_key += std::string(start, len);
|
cache_key += std::string(start, len);
|
||||||
|
@@ -453,11 +453,6 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
|
|||||||
# ---------------
|
# ---------------
|
||||||
# test permute
|
# test permute
|
||||||
# ---------------
|
# ---------------
|
||||||
|
|
||||||
# ---------------
|
|
||||||
# test permute
|
|
||||||
# ---------------
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype, shape, perm",
|
@pytest.mark.parametrize("dtype, shape, perm",
|
||||||
[(dtype, shape, perm) \
|
[(dtype, shape, perm) \
|
||||||
for dtype in ['float32']\
|
for dtype in ['float32']\
|
||||||
|
@@ -86,11 +86,16 @@ class dtype:
|
|||||||
|
|
||||||
class pointer_dtype:
|
class pointer_dtype:
|
||||||
def __init__(self, element_ty):
|
def __init__(self, element_ty):
|
||||||
|
if not isinstance(element_ty, dtype):
|
||||||
|
raise TypeError('element_ty is a {type(element_ty).__name__}.')
|
||||||
self.element_ty = element_ty
|
self.element_ty = element_ty
|
||||||
|
|
||||||
def handle(self, builder):
|
def handle(self, builder):
|
||||||
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'pointer<{self.element_ty}>'
|
||||||
|
|
||||||
# scalar types
|
# scalar types
|
||||||
int1 = dtype(ir.type.get_int1)
|
int1 = dtype(ir.type.get_int1)
|
||||||
int8 = dtype(ir.type.get_int8)
|
int8 = dtype(ir.type.get_int8)
|
||||||
|
Reference in New Issue
Block a user