[FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime from master back into the `triton-mlir` branch. It adds caching and just-in-time compilation to the triton-mlir project and paves the way for re-using tests from the master branch.
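The test changes below all follow the same pattern: the old `runtime.build_kernel` / `runtime.launch_kernel` pair goes away in favor of the backported runtime's `kernel[grid](...)` launch syntax, and compilation is exercised either through `triton.compile` or the internal `triton.compiler._compile` entry point. A minimal sketch of the new pattern, modeled on the `empty_kernel` test in this diff (the trivial kernel body is assumed for illustration; the `triton.compile` arguments and the `asm["cubin"]` access are taken from the diff itself, not from a stable public API):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
    pass  # body assumed: the test kernel deliberately does nothing


A = torch.zeros([1024], device="cuda")
grid = lambda META: (triton.cdiv(1024, META['BLOCK']),)

# New-style launch: the JIT compiles the kernel on first call and caches it.
empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)

# Explicit compilation: returns a kernel object whose `asm` dict holds the
# generated artifacts (e.g. the cubin), which is what the updated tests assert on.
kernel = triton.compile(empty_kernel,
                        "*fp32,i32,i32",
                        device=torch.cuda.current_device(),
                        constants={"BLOCK": 256})
assert len(kernel.asm["cubin"]) > 0
```

The `cast_check` and `math_kernel` tests apply the same backport but compile through `triton.compiler._compile(..., output="ttgir")` to inspect the generated IR rather than launching a kernel.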
@@ -2,8 +2,9 @@ import triton
 import triton.language as tl
 
 
+# TODO: function with no arguments don't work
 @triton.jit
-def cast_check():
+def cast_check(X):
     zero_0d = tl.zeros([], dtype=tl.float32)
     zero_1d = tl.zeros([2], dtype=tl.float32)
     zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
@@ -48,9 +49,9 @@ def cast_check():
 
 
 def test_cast_check():
-    kernel = triton.compile(cast_check,
-                            signature="",
-                            device=0,
-                            output="ttir")
+    kernel = triton.compiler._compile(cast_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
     assert (kernel)
     # TODO: Check types of the results
@@ -2,7 +2,6 @@ import torch
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 # trigger the torch.device implicitly to ensure cuda context initialization
 torch.zeros([10], device=torch.device('cuda'))
@@ -16,30 +15,18 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
 def test_empty_kernel_cubin_compile():
-
     device = torch.cuda.current_device()
-    cubin = triton.compile(empty_kernel,
-                           "*fp32,i32,i32",
-                           device=device,
-                           constants={"BLOCK": 256},
-                           output="cubin")
+    kernel = triton.compile(empty_kernel,
+                            "*fp32,i32,i32",
+                            device=device,
+                            constants={"BLOCK": 256})
 
-    print('cubin size:', len(cubin))
-    assert len(cubin) > 0
+    assert len(kernel.asm["cubin"]) > 0
 
 
 def test_empty_kernel_launch():
     device = torch.cuda.current_device()
-    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
-                                  constants={"BLOCK": 256},
-                                  num_warps=4,
-                                  num_stages=3)
     grid = lambda META: (
         triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
     )
 
     A = torch.zeros([1024], device="cuda")
-    runtime.launch_kernel(kernel=binary,
-                          grid=grid,
-                          device=device,
-                          X=A,
-                          stride_xm=256,
-                          BLOCK=tl.constexpr(256))
+    empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)
@@ -23,11 +23,11 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
 
 
 def test_empty_kernel_cubin_compile():
-    kernel = triton.compile(math_kernel,
-                            "*fp32,*fp32,*fp32,*fp32,i32",
-                            device=0,
-                            constants={"BLOCK_SIZE": 256},
-                            output="ttgir")  # "cubin"
+    kernel = triton.compiler._compile(math_kernel,
+                                      "*fp32,*fp32,*fp32,*fp32,i32",
+                                      device=0,
+                                      constants={"BLOCK_SIZE": 256},
+                                      output="ttgir")  # "cubin"
     assert kernel
     # TODO: Check if the values are correct.
     # TODO: Cover all the math operators
@@ -4,7 +4,6 @@ from torch.testing import assert_allclose
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 
 @triton.jit
@@ -40,29 +39,9 @@ def kernel(x_ptr, stride_xm,
     [2, 128, 64]
 ])
 def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
-    # TODO: this is to initialize the cuda context since it is not properly
-    # dealed with in the existing runtime, remove this when the runtime
-    # is updated
-    torch.zeros([10], device=torch.device('cuda'))
-    device = torch.cuda.current_device()
-    binary = runtime.build_kernel(kernel,
-                                  "*fp32,i32,*fp32,i32",
-                                  constants={"SIZE_M": SIZE_M,
-                                             "SIZE_N": SIZE_N},
-                                  num_warps=NUM_WARPS,
-                                  num_stages=3)
     grid = lambda META: (1, )
-
     x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
     z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
-    runtime.launch_kernel(kernel=binary,
-                          device=device,
-                          grid=grid,
-                          x_ptr=x,
-                          stride_xm=x.stride(0),
-                          z_ptr=z,
-                          stride_zn=z.stride(0),
-                          SIZE_M=tl.constexpr(SIZE_M),
-                          SIZE_N=tl.constexpr(SIZE_N))
+    kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
     golden_z = torch.t(x)
     assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
@@ -3,7 +3,6 @@ from torch.testing import assert_allclose
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 
 def vecadd_no_scf_tester(num_warps, block_size):
@@ -22,27 +21,13 @@ def vecadd_no_scf_tester(num_warps, block_size):
         z_ptrs = z_ptr + offset
         tl.store(z_ptrs, z)
 
-    torch.zeros([10], device=torch.device('cuda'))
-    device = torch.cuda.current_device()
-    binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
-                                  constants={"BLOCK_SIZE_N": block_size},
-                                  num_warps=num_warps,
-                                  num_stages=3)
-
     x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
     y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
     z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
 
     assert x.shape.numel() % block_size == 0, "Only test load without mask here"
     grid = lambda EA: (x.shape.numel() // block_size,)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
 
-    runtime.launch_kernel(kernel=binary,
-                          grid=grid,
-                          device=device,
-                          x_ptr=x,
-                          y_ptr=y,
-                          z_ptr=z,
-                          BLOCK_SIZE_N=tl.constexpr(block_size))
     golden_z = x + y
     assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)