[OPS] Add performance model for gemm/gemv (#397)

Significantly improves the performance of `triton.ops.matmul` in memory-bound settings by searching over many more block configurations, coupled with a performance model that drives the auto-tuning process so that only the most promising configurations are actually benchmarked.
Author: daadaada
Date: 2021-12-22 01:56:10 +08:00
Committed by: GitHub
Parent: 5cdb948c05
Commit: 39d4bfed83
12 changed files with 289 additions and 27 deletions
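For intuition about what "performance model" means here, below is a minimal roofline-style sketch of how one can estimate a matmul config's running time. This is illustrative only, not the model this commit ships in `triton.ops.matmul`; the function name and all hardware constants are assumptions.

```python
def simple_matmul_perf_model(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
                             num_stages, num_warps, **kwargs):
    """Return an estimated running time in ms; lower is better."""
    peak_flops = 100e12        # assumed peak FP16 throughput (FLOP/s)
    peak_bw = 1.5e12           # assumed DRAM bandwidth (bytes/s)
    # Compute time: a matmul performs 2*M*N*K FLOPs.
    compute_ms = 2 * M * N * K / peak_flops * 1e3
    # Memory time: each output block loads a BLOCK_M x K tile of A and a
    # K x BLOCK_N tile of B (2 bytes per FP16 element).
    num_blocks = -(-M // BLOCK_M) * -(-N // BLOCK_N)   # ceil-div grid size
    bytes_moved = num_blocks * (BLOCK_M * K + K * BLOCK_N) * 2
    memory_ms = bytes_moved / peak_bw * 1e3
    # The kernel is bound by whichever resource saturates first.
    return max(compute_ms, memory_ms)
```

Smaller blocks launch more tiles and re-read A and B more often, so in memory-bound regimes the model steers the tuner toward configurations that move fewer bytes.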


@@ -17,6 +17,10 @@ import triton
 import triton._C.libtriton.triton as _triton
 from filelock import FileLock
 import dbm
+import tempfile
+from typing import Optional, Dict
+import time
+from .tools.disasm import extract
 
 class CodeGenerator(ast.NodeVisitor):
@@ -508,6 +512,7 @@ class LoadedBinary:
                           device)
         self.bin = bin
         self.asm = bin.asm
+        self.sass = ''
         self.module = module
         self.kernel = kernel
         self.device = device
@@ -519,6 +524,19 @@ class LoadedBinary:
                         self.bin.num_warps * 32, 1, 1,
                         args, self.bin.shared_mem)
 
+    def get_sass(self, fun=None):
+        if self.sass:
+            return self.sass
+        fd, path = tempfile.mkstemp()
+        try:
+            with open(fd, 'wb') as cubin:
+                cubin.write(self.asm['cubin'])
+            self.sass = extract(path, fun)
+        finally:
+            os.remove(path)
+        self.asm['sass'] = self.sass
+        return self.sass
+
 
 class CompilationError(Exception):
     def __init__(self, src, node):
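The new `get_sass` caches the disassembly and relies on `extract` from `triton.tools.disasm` (which shells out to CUDA's `cuobjdump`), wrapped in the standard `tempfile.mkstemp` write-process-remove idiom. A self-contained sketch of that idiom; the helper name and payload are assumptions:

```python
import os
import tempfile

def with_temp_file(data: bytes, process):
    # Hypothetical helper illustrating the pattern get_sass() uses:
    # write bytes to a uniquely named temp file, hand the path to a
    # processing function, and always delete the file afterwards.
    fd, path = tempfile.mkstemp()
    try:
        with open(fd, 'wb') as f:  # open(fd) adopts the descriptor and closes it
            f.write(data)
        return process(path)
    finally:
        os.remove(path)

print(with_temp_file(b'\x00\x01\x02', os.path.getsize))  # -> 3
```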
@@ -530,8 +548,8 @@ class CompilationError(Exception):
 
 class OutOfResources(Exception):
     def __init__(self, required, limit, name):
-        self.message = f'out of resource: {name}'\
-                       f'Required: {required}'\
+        self.message = f'out of resource: {name}, '\
+                       f'Required: {required}, '\
                        f'Hardware limit: {limit}'
         super().__init__(self.message)
         self.args = (required, limit, name)
@@ -727,7 +745,13 @@ class Launcher:
 
 class Autotuner:
-    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
+    def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
+        '''
+        :param prune_configs_by: a dict of functions used to prune configs. Fields:
+            'perf_model': a performance model that predicts the running time of each config; returns an estimated running time.
+            'top_k': the number of configs to benchmark after pruning.
+            'prune_num_stages_by' (optional): a function that takes a list of configs (List[Config]) and returns the configs kept after pruning on num_stages.
+        '''
         if not configs:
             self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
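To pin down the contract of the optional field: `prune_num_stages_by` receives the full list of `Config` objects and returns the subset to keep, while `perf_model` follows the call signature visible in the `__call__` hunk further down (the kernel's runtime arguments plus each config's kwargs, `num_stages`, and `num_warps`). A minimal sketch of the pruning callable; the name and the threshold heuristic are assumptions:

```python
from typing import List

def keep_small_num_stages(configs: 'List[triton.Config]') -> 'List[triton.Config]':
    # Hypothetical pruning rule: drop deeply pipelined configs, e.g. ones
    # likely to exceed shared memory on the target GPU.
    return [c for c in configs if c.num_stages <= 4]
```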
@@ -744,7 +768,16 @@ class Autotuner:
                     args[i].zero_()
             self.hook = _hook
         self.arg_names = arg_names
+        # prune configs
+        if prune_configs_by:
+            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+            # .get() keeps prune_num_stages_by bound even when the key is absent
+            prune_num_stages_by = prune_configs_by.get('prune_num_stages_by')
+        else:
+            perf_model, top_k, prune_num_stages_by = None, None, None
+        self.perf_model, self.configs_top_k = perf_model, top_k
+        self.prune_num_stages_by = prune_num_stages_by
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
@@ -768,13 +801,29 @@ class Autotuner:
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
+                # prune configs
+                pruned_configs = self.configs
+                if self.prune_num_stages_by:
+                    pruned_configs = self.prune_num_stages_by(self.configs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs) \
-                           for config in self.configs}
+                           for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
+                self.configs_timings = timings
             config = self.cache[key]
         else:
             config = self.configs[0]
+        self.best_config = config
         if config.pre_hook != None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
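Two details worth noting in the selection logic above: `top_k` may be given as a float, interpreted as a fraction of the config list (note it scales by the unpruned `self.configs`, not the pruned list), and the tuner now records `bench_time`, `configs_timings`, and `best_config` for later inspection. A self-contained illustration of the `top_k` arithmetic:

```python
# top_k may be an absolute count (int) or a fraction (float <= 1.0).
configs = list(range(120))             # stand-in for 120 Config objects
top_k = 0.1
if isinstance(top_k, float) and top_k <= 1.0:
    top_k = int(len(configs) * top_k)  # 120 * 0.1 -> 12
assert top_k == 12                     # only 12 configs get benchmarked
```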
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
         module = inspect.getmodule(func)
         if module and module.__name__.startswith('triton.'):
             return
+        if inspect.isbuiltin(func):
+            return
         if not hasattr(func, 'hash'):
             src = textwrap.dedent(inspect.getsource(func))
             tree = ast.parse(src)
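The new `inspect.isbuiltin` guard matters because the hashing path below it calls `inspect.getsource`, which raises `TypeError` for builtins, since they have no Python source to hash. A quick stand-alone demonstration:

```python
import inspect

print(inspect.isbuiltin(len))   # True: builtins must be skipped
try:
    inspect.getsource(len)      # builtins carry no Python source
except TypeError as exc:
    print('getsource failed:', exc)
```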
@@ -957,8 +1008,16 @@ class Config:
         self.num_stages = num_stages
         self.pre_hook = pre_hook
 
+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f'{k}: {v}')
+        res.append(f'num_warps: {self.num_warps}')
+        res.append(f'num_stages: {self.num_stages}')
+        return ', '.join(res)
 
-def autotune(configs, key, reset_to_zero=None):
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
     :type configs: list[triton.Config]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions used to prune configs. Fields:
+        'perf_model': a performance model that predicts the running time of each config; returns an estimated running time.
+        'top_k': the number of configs to benchmark after pruning.
+        'prune_num_stages_by' (optional): a function that takes a list of configs (List[Config]) and returns the configs kept after pruning on num_stages.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
     """
     def decorator(fn):
         def wrapper(kernel):
-            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
         fn.kernel_decorators.append(wrapper)
         return fn
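Putting the new parameter together, a hedged usage sketch: the decorator API (`triton.autotune`, `triton.Config`, `prune_configs_by`) is taken from this diff, but the kernel, the block values, and the toy perf model are illustrative assumptions.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=8, num_stages=3),
        # ... many more block configs ...
    ],
    key=['M', 'N', 'K'],
    prune_configs_by={
        # Toy model: prefer wider blocks. A real model estimates time from
        # FLOPs, bytes moved, occupancy, etc. Extra kernel args land in **meta.
        'perf_model': lambda M, N, K, num_stages, num_warps, **meta:
            M * N * K / (meta['BLOCK_M'] * meta['BLOCK_N'] * num_warps),
        'top_k': 10,   # benchmark at most the 10 best-predicted configs
    },
)
@triton.jit
def matmul_kernel(A, B, C, M, N, K, **meta):
    # kernel body omitted -- illustrative only
    ...
```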
@@ -1023,7 +1086,6 @@ def heuristics(values):
                 assert v not in meta
                 meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
             return kernel(*args, **meta)
-        return fun
         fn.kernel_decorators.append(wrapper)