[CODEGEN] Performance improvement on A100 (#125)

Improved code generation for Ampere GPUs.

    * Make the layout pass recognize the multistage pipelined pattern.
    * The pipeline pass now automates the multistage pipelining transformation.
    * Remove extra barriers (from the prefetch pass & WAR hazards) on Ampere.
    * Update the code generator (generator.cc) so that Triton generates n-buffered shared-memory loads/stores (see the usage sketch below).
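
The pipelining depth is exposed as a new num_stages launch parameter (default 2) next to num_warps, and it becomes part of the compilation cache key, as the diff below shows. A minimal usage sketch follows; it is not part of this commit, and the toy kernel, its arguments, and the triton.language import are assumptions about the Triton frontend of this period:

    # Sketch only (not from this commit): a toy vector-add showing where the
    # new num_warps / num_stages launch keywords go. The kernel body and the
    # triton.language import are assumed from the Triton frontend of this era.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(X, Y, Z, N, **meta):
        BLOCK = meta['BLOCK']
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(Z + offs, tl.load(X + offs, mask=mask) + tl.load(Y + offs, mask=mask), mask=mask)

    N = 4096
    x = torch.randn(N, device='cuda')
    y = torch.randn(N, device='cuda')
    z = torch.empty(N, device='cuda')
    grid = lambda meta: ((N + meta['BLOCK'] - 1) // meta['BLOCK'],)
    # num_stages selects how many shared-memory buffers the pipeline pass uses
    # (n-buffering on Ampere); it defaults to 2 and now enters the cache key.
    add_kernel[grid](x, y, z, N, BLOCK=1024, num_warps=4, num_stages=3)

For a pure vector add nothing is staged through shared memory, so num_stages is inert there; the knob targets kernels such as tl.dot-based matmuls, where deeper n-buffering overlaps global-memory loads with compute.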
Authored by daadaada on 2021-06-21 14:25:13 +08:00
Committed by Philippe Tillet
Parent: 5a51f3e529
Commit: d8d6b715c8
21 changed files with 855 additions and 174 deletions


@@ -407,13 +407,14 @@ class CodeGenerator(ast.NodeVisitor):
 class Binary:
-    def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
         self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
+        self.num_stages = num_stages
         self.sass = None
     def asm(self, mode):
@@ -447,6 +448,13 @@ class CompilationError(Exception):
         self.message += '\n Error: ' + str(err)
         super().__init__(self.message)
+class OutOfResources(Exception):
+    def __init__(self, required, limit, name):
+        self.message = f'out of resource: {name}'\
+                       f'Required: {required}'\
+                       f'Hardware limit: {limit}'
+        super().__init__(self.message)
 class Kernel:
     @staticmethod
@@ -513,7 +521,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
-    def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -535,10 +543,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
-        return Binary(mod, ker, num_warps, shared_mem, ir_asm)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        if shared_mem > tt_device.max_shared_memory():
+            raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
+        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
-    def __call__(self, *wargs, grid, num_warps=4, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -554,12 +564,12 @@ class Kernel:
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
-        key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
+        key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
         cache = self.fn.cache
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
+                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -585,7 +595,7 @@ class Launcher:
 class Autotuner:
     def __init__(self, kernel, arg_names, configs, key):
         if not configs:
-            self.configs = [Config(dict(), num_warps=4)]
+            self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
             self.configs = configs
         self.key_idx = [arg_names.index(k) for k in key]
@@ -603,7 +613,7 @@ class Autotuner:
             )
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.meta)
-        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
+        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
     def __call__(self, *args, **meta):
@@ -616,7 +626,7 @@ class Autotuner:
             config = self.cache[key]
         else:
             config = self.configs[0]
-        return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
 class JITFunction:
@@ -671,9 +681,10 @@ class JITFunction:
 class Config:
-    def __init__(self, meta, num_warps=4):
+    def __init__(self, meta, num_warps=4, num_stages=2):
         self.meta = meta
         self.num_warps = num_warps
+        self.num_stages = num_stages
 def autotune(configs, key):
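
Since num_stages is now a per-Config knob next to num_warps (see the hunk above), the pipelining depth can be autotuned together with the tile sizes. A hedged sketch reusing the toy kernel shape from the earlier example; the triton.Config / triton.autotune spellings, the config values, and the key are assumptions, not part of this commit:

    # Sketch only: config values and the key are illustrative; num_stages as a
    # Config argument is what this diff adds.
    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK': 1024}, num_warps=4, num_stages=2),
            # deeper software pipelining: three shared-memory buffers on Ampere
            triton.Config({'BLOCK': 2048}, num_warps=8, num_stages=3),
        ],
        key=['N'],
    )
    @triton.jit
    def add_kernel_tuned(X, Y, Z, N, **meta):
        BLOCK = meta['BLOCK']
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(Z + offs, tl.load(X + offs, mask=mask) + tl.load(Y + offs, mask=mask), mask=mask)

Benchmarking each config on first call and caching the winner per key value is handled by the Autotuner class shown in the hunks above.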