fix for test_cast

This commit is contained in:
Michael Melesse
2022-10-26 21:34:58 +00:00
parent 8ecab462f6
commit ed9638801a

View File

@@ -740,8 +740,6 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = np.array([x0], dtype=getattr(np, dtype_x))
x_tri = to_triton(x)
SIZE = 1024
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):