import torch
import triton
import triton.language as tl
@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    """Placeholder Triton kernel that performs no work.

    The body is empty, so launching this kernel only exercises Triton's
    JIT compile-and-launch path; none of the arguments are read or written.

    Args:
        X: Unused by the body. Presumably the tensor that the two strides
            describe -- confirm against the call site.
        stride_xm: Unused by the body.
        stride_xn: Unused by the body.
        BLOCK: Compile-time constant (annotated ``tl.constexpr``); unused
            by the body.
    """
    # The body must be indented under the ``def``; otherwise Python rejects
    # the module with an IndentationError before Triton ever sees it.
    pass
# Allocate a one-element tensor of random values on the CUDA device to pass
# as the kernel's first argument.
X = torch.randn(1, device="cuda")
# Launch `kernel` on a one-element grid `(1,)`, passing 1 for both stride
# arguments and BLOCK=1024 as the constexpr. BLOCK is larger than X's single
# element, but the kernel body is a no-op, so no memory is accessed.
# `pgm` keeps whatever the launch returns -- presumably Triton's handle to
# the compiled kernel; confirm against the installed Triton version.
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)