[LANGUAGE] Added support for bitcast (#119)

This commit is contained in:
Philippe Tillet
2021-05-21 02:47:53 -04:00
committed by Philippe Tillet
parent 38ab4e955a
commit 59b0ac672a
5 changed files with 57 additions and 3 deletions

View File

@@ -129,6 +129,7 @@ void init_triton_frontend(py::module &&m) {
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);

View File

@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
triton.testing.assert_allclose(x_ref, last_sum)
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) for dtype_x in dtypes \
for dtype_z in dtypes
] + [
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test load
# ---------------

View File

@@ -76,7 +76,7 @@ class pointer_dtype:
self.element_ty = element_ty
def handle(self, builder):
return ir.type.make_ptr(self.element_ty, 1)
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
int1 = dtype(ir.type.get_int1)
@@ -249,8 +249,11 @@ class block:
return ret
@builtin
def to(self, dtype, builder=None):
return frontend.cast(self, dtype.handle(builder), builder)
def to(self, dtype, bitcast=False, builder=None):
dtype = dtype.handle(builder)
if bitcast:
return frontend.bitcast(self, dtype, builder)
return frontend.cast(self, dtype, builder)
# -----------------------