[Backend] Vectorize Load/Store Ops (#86)
This PR does the following:
- Refactors the Load and Store op codegen, rewriting both with the same logic so they share most of their code.
- Adds support for vectorized load/store.
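For context, here is a rough sketch (not part of the diff, adapted from the vecadd test below; the kernel name and sizes are illustrative only) of the kind of kernel this change targets: its loads and stores use contiguous, statically sized offsets, which is the access pattern the coalesce pass can now lower to vectorized load/store instructions.

```python
import triton
import triton.language as tl


@triton.jit
def vec_add_kernel(x_ptr, y_ptr, z_ptr, BLOCK_SIZE_N: tl.constexpr):
    # Each program instance handles one contiguous BLOCK_SIZE_N-element tile.
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Contiguous, aligned, unmasked accesses like these are the candidates
    # for the new vectorized load/store codegen.
    x = tl.load(x_ptr + offset)
    y = tl.load(y_ptr + offset)
    tl.store(z_ptr + offset, x + y)
```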
@@ -1649,6 +1649,10 @@ void init_triton_ir(py::module &&m) {
      .def(
          "add_sccp_pass",
          [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
      .def("add_coalesce_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUCoalescePass());
           })
      .def("add_symbol_dce_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createSymbolDCEPass());
@@ -29,7 +29,6 @@ def test_empty_kernel_cubin_compile():
def test_empty_kernel_launch():
    device = torch.cuda.current_device()
    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
                                  device=device,
                                  constants={"BLOCK": 256},
                                  num_warps=4,
                                  num_stages=3)
@@ -38,11 +37,9 @@ def test_empty_kernel_launch():
                                  )

    A = torch.zeros([1024], device="cuda")
    runtime.launch_kernel(fn=empty_kernel,
                          binary=binary,
    runtime.launch_kernel(kernel=binary,
                          grid=grid,
                          num_warps=4,
                          num_stages=3,
                          device=device,
                          X=A,
                          stride_xm=256,
                          BLOCK=tl.constexpr(256))
@@ -5,17 +5,12 @@ import triton
import triton.language as tl
import triton.runtime as runtime

NUM_WARPS = 4
BLOCK_SIZE = 256

# triton kernel


def test_vecadd_no_scf():
def vecadd_no_scf_tester(num_warps, block_size):
    @triton.jit
    def kernel(x_ptr, stride_xn,
               y_ptr, stride_yn,
               z_ptr, stride_zn,
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               BLOCK_SIZE_N: tl.constexpr):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -27,37 +22,35 @@ def test_vecadd_no_scf():
        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z)

    # TODO: add this to CI, to make sure the compilation flow is at least OK
    # before we have GPU machines for CI.
    # ptx, shem_size, kernel_name = triton.compile(kernel,
    #                                              "*fp32,i32,*fp32,i32,*fp32,i32",
    #                                              constants={"BLOCK_SIZE_N": 256},
    #                                              num_warps=NUM_WARPS,
    #                                              device=0, output="ptx")

    torch.zeros([10], device=torch.device('cuda'))
    device = torch.cuda.current_device()
    binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
                                  device=device,
                                  constants={"BLOCK_SIZE_N": BLOCK_SIZE},
                                  num_warps=NUM_WARPS,
    binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
                                  constants={"BLOCK_SIZE_N": block_size},
                                  num_warps=num_warps,
                                  num_stages=3)
    grid = lambda META: (1, )

    x = torch.randn((256,), device='cuda', dtype=torch.float32)
    y = torch.randn((256,), device='cuda', dtype=torch.float32)
    z = torch.empty((256,), device=x.device, dtype=x.dtype)
    runtime.launch_kernel(fn=kernel,
                          binary=binary,
    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    assert x.shape.numel() % block_size == 0, "Only test load without mask here"
    grid = lambda EA: (x.shape.numel() // block_size,)

    runtime.launch_kernel(kernel=binary,
                          grid=grid,
                          num_warps=NUM_WARPS,
                          num_stages=3,
                          device=device,
                          x_ptr=x,
                          stride_xn=x.stride(0),
                          y_ptr=y,
                          stride_yn=y.stride(0),
                          z_ptr=z,
                          stride_zn=z.stride(0),
                          BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
                          BLOCK_SIZE_N=tl.constexpr(block_size))
    golden_z = x + y
    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)


def test_vecadd_no_scf():
    vecadd_no_scf_tester(num_warps=2, block_size=256)
    vecadd_no_scf_tester(num_warps=1, block_size=256)


if __name__ == '__main__':
    test_vecadd_no_scf()
@@ -798,7 +798,8 @@ def optimize_tritongpu_ir(mod, num_stages):
    pm.add_tritongpu_pipeline_pass(num_stages)
    pm.add_canonicalizer_pass()
    pm.add_cse_pass()
    # pm.add_triton_gpu_combine_pass()
    pm.add_coalesce_pass()
    pm.add_triton_gpu_combine_pass()
    pm.add_triton_gpu_verifier_pass()
    pm.run(mod)
    return mod
@@ -8,7 +8,7 @@ import os
import subprocess
import tempfile
import textwrap
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch

@@ -256,41 +256,126 @@ class JITFunction:
        return f"JITFunction({self.module}:{self.fn.__name__})"


def pow2_divisor(N):
    if N % 16 == 0:
        return 16
    if N % 8 == 0:
        return 8
    if N % 4 == 0:
        return 4
    if N % 2 == 0:
        return 2
    return 1


class _KernelCache:
    def __init__(self,
                 fn: JITFunction,
                 fn_type: str,
                 constants: Dict[str, Any],
                 num_warps: int = 4,
                 num_stages: int = 3):
        # hold the arguments for building a kernel
        self.fn = fn
        self.fn_type = fn_type
        self.constants = constants
        self.num_warps = num_warps
        self.num_stages = num_stages

        # kernel compilation cache
        self._binary_cache: Optional[LoadedBinary] = None

    @property
    def binary_cache(self):
        return self._binary_cache

    def set_binary_cache(self, binary: LoadedBinary):
        assert binary
        assert not self._binary_cache, "cannot set binary cache duplicately"
        self._binary_cache = binary


def build_kernel(fn: JITFunction,
                 fn_type: str,
                 device: int,
                 constants: Dict[str, Any],
                 num_warps: int = 4,
                 num_stages: int = 3,
                 ) -> LoadedBinary:
    cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
    assert cubin
    assert kernel_name

    backend = _triton.runtime.backend.CUDA

    max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
    assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)

    asm = dict(cubin=cubin)
    binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
    loaded_binary = LoadedBinary(device, binary)
    return loaded_binary
                 ) -> _KernelCache:
    return _KernelCache(fn, fn_type, constants, num_warps, num_stages)


def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
    kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
torch_dtype_to_bytes = {
    torch.int8: 1,
    torch.uint8: 1,

    torch.int16: 2,
    torch.short: 2,

    torch.int: 4,
    torch.int32: 4,

    torch.long: 8,
    torch.int64: 8,

    torch.float32: 4,
    torch.float: 4,

    torch.float16: 2,
    torch.half: 2,
    torch.bfloat16: 2,
    # free to extend
}


def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
    def is_tensor(arg):
        return hasattr(arg, 'data_ptr') # a torch.tensor

    # prepare function args for compile
    kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
    wargs = list(wargs)
    for i, pos in enumerate(sorted(kwargs)):
        wargs.insert(pos + i, kwargs[pos])
    assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs))
    assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs))

    if not kernel.binary_cache:
        # build the kernel cache
        backend = _triton.runtime.backend.CUDA

        attributes = dict()
        for i, arg in enumerate(wargs):
            if i in kernel.fn.do_not_specialize:
                continue
            if isinstance(arg, int):
                attributes[i] = pow2_divisor(arg)
            elif is_tensor(arg):
                assert arg.dtype in torch_dtype_to_bytes
                addr = arg.data_ptr()
                range_size = _triton.runtime.get_pointer_range_size(addr)
                divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
                attributes[i] = divisibility

        attributes_ = dict()
        for i, value in attributes.items():
            attributes_[kernel.fn.arg_names[i]] = value

        cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
        assert cubin
        assert kernel_name

        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
        assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)

        asm = dict(cubin=cubin)
        binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
        loaded_binary = LoadedBinary(device, binary)
        kernel.set_binary_cache(loaded_binary)

    device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    stream = get_cuda_stream(device)

    _triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
                                  stream, num_warps, num_stages, grid)
    _triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
                                  stream, kernel.num_warps, kernel.num_stages, grid)


# -----------------------------------------------------------------------------