[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for the Ampere GPUs. * Make the layout pass recognize the multistage pipelined pattern. * Now the pipeline pass can automate the multistage pipelining transformation. * Remove extra barriers (from the prefetch pass & WAR) on Ampere. * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores.
This commit is contained in:
committed by
Philippe Tillet
parent
5a51f3e529
commit
d8d6b715c8
@@ -33,7 +33,10 @@ void init_triton_driver(py::module &&m) {
|
||||
CUdevice handle;
|
||||
drv::dispatch::cuDeviceGet(&handle, dev_id);
|
||||
return new drv::cu_device(handle, take_ownership);
|
||||
}));
|
||||
}))
|
||||
.def("max_shared_memory", [](drv::cu_device *self) {
|
||||
return self->max_shared_memory();
|
||||
});
|
||||
// host device
|
||||
py::class_<drv::host_device, drv::device>(m, "host_device")
|
||||
.def(py::init<>());
|
||||
@@ -75,11 +78,11 @@ void init_triton_driver(py::module &&m) {
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def(
|
||||
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
|
||||
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
|
||||
drv::module *mod;
|
||||
drv::kernel *ker;
|
||||
size_t shared_mem;
|
||||
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
|
||||
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
|
||||
std::stringstream ss;
|
||||
ir::print(ir, ss);
|
||||
return std::make_tuple(mod, ker, shared_mem, ss.str());
|
||||
|
@@ -27,7 +27,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
rc = op(ra, rb)
|
||||
rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
|
||||
# torch result
|
||||
ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
|
@@ -5,56 +5,69 @@ import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
# # split-k
|
||||
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
|
||||
# # variable input
|
||||
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
torch.manual_seed(0)
|
||||
# nuke kernel decorators -- will set meta-parameters manually
|
||||
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
|
||||
configs = [triton.Config(meta=META, num_warps=NWARP)]
|
||||
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
|
||||
kernel = triton.ops._matmul.kernel
|
||||
decorators = kernel.kernel_decorators
|
||||
kernel.kernel_decorators = []
|
||||
@@ -72,5 +85,5 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
b = b.t() if BT else b
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.ops.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
|
||||
assert triton.testing.allclose(th_c, tt_c)
|
||||
|
@@ -407,13 +407,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
|
||||
class Binary:
|
||||
def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
|
||||
def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
|
||||
# cache ir asm
|
||||
self.ir_asm = ir_asm
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
self.sass = None
|
||||
|
||||
def asm(self, mode):
|
||||
@@ -447,6 +448,13 @@ class CompilationError(Exception):
|
||||
self.message += '\n Error: ' + str(err)
|
||||
super().__init__(self.message)
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}'\
|
||||
f'Required: {required}'\
|
||||
f'Hardware limit: {limit}'
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class Kernel:
|
||||
@staticmethod
|
||||
@@ -513,7 +521,7 @@ class Kernel:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
|
||||
# explicitly set device
|
||||
torch.cuda.set_device(device.index)
|
||||
# create IR module
|
||||
@@ -535,10 +543,12 @@ class Kernel:
|
||||
raise CompilationError(self.fn.src, node, e)
|
||||
tt_device = _triton.driver.cu_device(device.index, False)
|
||||
# Compile to machine code
|
||||
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
|
||||
return Binary(mod, ker, num_warps, shared_mem, ir_asm)
|
||||
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
|
||||
if shared_mem > tt_device.max_shared_memory():
|
||||
raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
|
||||
return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, **meta):
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
|
||||
# device inference
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
if len(tensor_idxs) == 0:
|
||||
@@ -554,12 +564,12 @@ class Kernel:
|
||||
attr_key = frozenset(attributes.items())
|
||||
meta_key = frozenset(meta.items())
|
||||
const_key = frozenset(constants.items())
|
||||
key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
|
||||
key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
|
||||
cache = self.fn.cache
|
||||
if key not in cache:
|
||||
# compile and cache configuration if necessary
|
||||
cache[key] = self._compile(
|
||||
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
|
||||
*wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
|
||||
)
|
||||
# pack arguments
|
||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||
@@ -585,7 +595,7 @@ class Launcher:
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key):
|
||||
if not configs:
|
||||
self.configs = [Config(dict(), num_warps=4)]
|
||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
@@ -603,7 +613,7 @@ class Autotuner:
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.meta)
|
||||
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
|
||||
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return triton.testing.do_bench(kernel_call)
|
||||
|
||||
def __call__(self, *args, **meta):
|
||||
@@ -616,7 +626,7 @@ class Autotuner:
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
|
||||
|
||||
|
||||
class JITFunction:
|
||||
@@ -671,9 +681,10 @@ class JITFunction:
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, meta, num_warps=4):
|
||||
def __init__(self, meta, num_warps=4, num_stages=2):
|
||||
self.meta = meta
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
|
||||
|
||||
def autotune(configs, key):
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import os
|
||||
from .code_gen import OutOfResources
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
@@ -8,6 +9,15 @@ except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
def catch_oor(kernel, pytest_handle=None):
|
||||
try:
|
||||
res = kernel()
|
||||
except OutOfResources as e:
|
||||
if pytest_handle:
|
||||
pytest_handle.skip(str(e))
|
||||
return None
|
||||
return res
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
|
Reference in New Issue
Block a user