[Triton-MLIR][Backend] Add SCF lowering in the backend (#750)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
@@ -1185,8 +1186,12 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUVerifier());
|
||||
})
|
||||
.def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||
.def("add_triton_gpu_to_llvm",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||
})
|
||||
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createLowerToCFGPass());
|
||||
});
|
||||
}
|
||||
|
||||
|
79
python/tests/test_vecadd.py
Normal file
79
python/tests/test_vecadd.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
|
||||
[4, 256],
|
||||
[2, 256],
|
||||
[1, 256],
|
||||
])
|
||||
def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
|
||||
|
||||
@triton.jit
|
||||
def kernel(x_ptr,
|
||||
y_ptr,
|
||||
z_ptr,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
x_ptrs = x_ptr + offset
|
||||
y_ptrs = y_ptr + offset
|
||||
x = tl.load(x_ptrs)
|
||||
y = tl.load(y_ptrs)
|
||||
z = x + y
|
||||
z_ptrs = z_ptr + offset
|
||||
tl.store(z_ptrs, z)
|
||||
|
||||
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
|
||||
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
|
||||
|
||||
golden_z = x + y
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
|
||||
[4, 256, 1],
|
||||
[4, 1024, 256],
|
||||
])
|
||||
def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
|
||||
|
||||
@triton.jit
|
||||
def kernel(x_ptr,
|
||||
y_ptr,
|
||||
z_ptr,
|
||||
BLOCK_SIZE,
|
||||
ITER_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
for i in range(0, BLOCK_SIZE, ITER_SIZE):
|
||||
offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
|
||||
x_ptrs = x_ptr + offset
|
||||
y_ptrs = y_ptr + offset
|
||||
x = tl.load(x_ptrs)
|
||||
y = tl.load(y_ptrs)
|
||||
z = x + y
|
||||
z_ptrs = z_ptr + offset
|
||||
tl.store(z_ptrs, z)
|
||||
x_ptr += ITER_SIZE
|
||||
y_ptr += ITER_SIZE
|
||||
z_ptr += ITER_SIZE
|
||||
|
||||
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
|
||||
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
|
||||
BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
|
||||
|
||||
golden_z = x + y
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
||||
# TODO: test_vecadd with mask
|
@@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def vecadd_no_scf_tester(num_warps, block_size):
|
||||
@triton.jit
|
||||
def kernel(x_ptr,
|
||||
y_ptr,
|
||||
z_ptr,
|
||||
BLOCK_SIZE_N: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
x_ptrs = x_ptr + offset
|
||||
y_ptrs = y_ptr + offset
|
||||
x = tl.load(x_ptrs)
|
||||
y = tl.load(y_ptrs)
|
||||
z = x + y
|
||||
z_ptrs = z_ptr + offset
|
||||
tl.store(z_ptrs, z)
|
||||
|
||||
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
||||
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = lambda EA: (x.shape.numel() // block_size,)
|
||||
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
|
||||
|
||||
golden_z = x + y
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def test_vecadd_no_scf():
|
||||
vecadd_no_scf_tester(num_warps=4, block_size=256)
|
||||
vecadd_no_scf_tester(num_warps=2, block_size=256)
|
||||
vecadd_no_scf_tester(num_warps=1, block_size=256)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_vecadd_no_scf()
|
Reference in New Issue
Block a user