[LANGUAGE] Added support for bitcast (#119)

This commit is contained in:
Philippe Tillet
2021-05-21 02:47:53 -04:00
committed by Philippe Tillet
parent 38ab4e955a
commit 59b0ac672a
5 changed files with 57 additions and 3 deletions

View File

@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
triton.testing.assert_allclose(x_ref, last_sum)
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) for dtype_x in dtypes \
for dtype_z in dtypes
] + [
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test load
# ---------------