fix for test_cast
This commit is contained in:
@@ -740,8 +740,6 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
|||||||
x = np.array([x0], dtype=getattr(np, dtype_x))
|
x = np.array([x0], dtype=getattr(np, dtype_x))
|
||||||
x_tri = to_triton(x)
|
x_tri = to_triton(x)
|
||||||
|
|
||||||
SIZE = 1024
|
|
||||||
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
|
|
||||||
# triton kernel
|
# triton kernel
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||||
|
Reference in New Issue
Block a user