[LANGUAGE] Added support for bitcast (#119)

This commit is contained in:
Philippe Tillet
2021-05-21 02:47:53 -04:00
committed by Philippe Tillet
parent 38ab4e955a
commit 59b0ac672a
5 changed files with 57 additions and 3 deletions

View File

@@ -62,6 +62,7 @@ struct dispatch{
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators

View File

@@ -414,6 +414,23 @@ std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value
return std::make_tuple(lhs, rhs);
}
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// Bitcast
int src_bits = src_sca_ty->get_primitive_size_in_bits();
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
if( src_bits!= dst_bits)
throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
"to data-type of size " + std::to_string(dst_bits));
return builder->create_cast(ir::BitCast, input, dst_ty);
}
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())

View File

@@ -129,6 +129,7 @@ void init_triton_frontend(py::module &&m) {
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);

View File

@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
triton.testing.assert_allclose(x_ref, last_sum)
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) for dtype_x in dtypes \
for dtype_z in dtypes
] + [
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test load
# ---------------

View File

@@ -76,7 +76,7 @@ class pointer_dtype:
self.element_ty = element_ty
def handle(self, builder):
return ir.type.make_ptr(self.element_ty, 1)
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
int1 = dtype(ir.type.get_int1)
@@ -249,8 +249,11 @@ class block:
return ret
@builtin
def to(self, dtype, builder=None):
return frontend.cast(self, dtype.handle(builder), builder)
def to(self, dtype, bitcast=False, builder=None):
dtype = dtype.handle(builder)
if bitcast:
return frontend.bitcast(self, dtype, builder)
return frontend.cast(self, dtype, builder)
# -----------------------