"""Launch a no-op Triton kernel once on a one-element CUDA tensor.

The kernel body does nothing, so running this script exercises only Triton's
compile-and-launch path for ``kernel`` with ``BLOCK=1024``.
"""
import torch
import triton
import triton.language as tl


@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    # Intentionally empty: nothing is read from or written to X, and the
    # stride arguments are unused. BLOCK is a compile-time constant
    # (tl.constexpr), so the kernel is still compiled for BLOCK=1024.
    pass


# A single random element allocated on the default CUDA device.
X = torch.randn(1, device="cuda")
# Launch a grid of exactly one program. Both strides are passed as 1; note that
# X is 1-D, so stride_xn has no corresponding dimension. This is harmless only
# because the kernel body never uses the strides.
# NOTE(review): `pgm` holds whatever the launch returns (a compiled-kernel
# handle in recent Triton releases) — confirm for the pinned Triton version.
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)