[LANGUAGE] Added support for bitcast (#119)
This commit is contained in:
committed by
Philippe Tillet
parent
38ab4e955a
commit
59b0ac672a
@@ -129,6 +129,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
|
||||
m.def("cast", &ir::dispatch::cast, ret::reference);
|
||||
// memory
|
||||
m.def("load", &ir::dispatch::load, ret::reference);
|
||||
|
@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
|
||||
triton.testing.assert_allclose(x_ref, last_sum)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test cast
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||
(dtype_x, dtype_z, False) for dtype_x in dtypes \
|
||||
for dtype_z in dtypes
|
||||
] + [
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
|
||||
tl.store(Z, z)
|
||||
|
||||
# triton result
|
||||
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, BITCAST=bitcast)
|
||||
# torch result
|
||||
if bitcast:
|
||||
import numpy as np
|
||||
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
|
||||
z_ref = torch.from_numpy(z_ref).to(device)
|
||||
else:
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
|
@@ -76,7 +76,7 @@ class pointer_dtype:
|
||||
self.element_ty = element_ty
|
||||
|
||||
def handle(self, builder):
|
||||
return ir.type.make_ptr(self.element_ty, 1)
|
||||
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
||||
|
||||
|
||||
int1 = dtype(ir.type.get_int1)
|
||||
@@ -249,8 +249,11 @@ class block:
|
||||
return ret
|
||||
|
||||
@builtin
|
||||
def to(self, dtype, builder=None):
|
||||
return frontend.cast(self, dtype.handle(builder), builder)
|
||||
def to(self, dtype, bitcast=False, builder=None):
|
||||
dtype = dtype.handle(builder)
|
||||
if bitcast:
|
||||
return frontend.bitcast(self, dtype, builder)
|
||||
return frontend.cast(self, dtype, builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
|
Reference in New Issue
Block a user