[LANGUAGE] Added support for bitcast (#119)
This commit is contained in:
committed by
Philippe Tillet
parent
38ab4e955a
commit
59b0ac672a
@@ -263,6 +263,38 @@ def test_atomic_add(dtype_x, device='cuda'):
|
||||
triton.testing.assert_allclose(x_ref, last_sum)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test cast
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||
(dtype_x, dtype_z, False) for dtype_x in dtypes \
|
||||
for dtype_z in dtypes
|
||||
] + [
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
|
||||
tl.store(Z, z)
|
||||
|
||||
# triton result
|
||||
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, BITCAST=bitcast)
|
||||
# torch result
|
||||
if bitcast:
|
||||
import numpy as np
|
||||
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
|
||||
z_ref = torch.from_numpy(z_ref).to(device)
|
||||
else:
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user