[OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
This commit is contained in:
@@ -17,6 +17,10 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from filelock import FileLock
|
||||
import dbm
|
||||
import tempfile
|
||||
from typing import Optional, Dict
|
||||
import time
|
||||
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -508,6 +512,7 @@ class LoadedBinary:
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
self.sass = ''
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.device = device
|
||||
@@ -519,6 +524,19 @@ class LoadedBinary:
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
args, self.bin.shared_mem)
|
||||
|
||||
def get_sass(self, fun=None):
    """Return the SASS for this binary, generating it on first request.

    The cubin stored in ``self.asm['cubin']`` is written to a temporary
    file, that file is handed to ``extract`` together with ``fun``, and the
    result is memoised in ``self.sass`` and mirrored into
    ``self.asm['sass']``. The temporary file is always removed, even when
    writing or extraction raises.

    :param fun: forwarded unchanged to ``extract`` as its second argument
        (presumably a function-name filter -- confirm against ``extract``).
    :return: the value produced by ``extract`` (or the memoised one).
    """
    # Fast path: reuse an earlier result.
    # NOTE(review): the memoised value is returned regardless of ``fun``, so
    # a later call with a different ``fun`` gets the first call's output.
    if self.sass:
        return self.sass
    fd, path = tempfile.mkstemp()
    try:
        # open() takes ownership of the descriptor returned by mkstemp and
        # closes it on leaving the ``with`` block, so the data is on disk
        # before ``extract`` reads the file by path.
        with open(fd, 'wb') as cubin:
            cubin.write(self.asm['cubin'])
        self.sass = extract(path, fun)
    finally:
        os.remove(path)
    self.asm['sass'] = self.sass
    return self.sass
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
@@ -530,8 +548,8 @@ class CompilationError(Exception):
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}'\
|
||||
f'Required: {required}'\
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
super().__init__(self.message)
|
||||
self.args = (required, limit, name)
|
||||
@@ -727,7 +745,13 @@ class Launcher:
|
||||
|
||||
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||
else:
|
||||
@@ -744,7 +768,16 @@ class Autotuner:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'prune_num_stages_by' in prune_configs_by:
|
||||
prune_num_stages_by = prune_configs_by['prune_num_stages_by']
|
||||
else:
|
||||
perf_model, top_k, prune_num_stages_by = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.prune_num_stages_by = prune_num_stages_by
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
@@ -768,13 +801,29 @@ class Autotuner:
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.configs
|
||||
if self.prune_num_stages_by:
|
||||
pruned_configs = self.prune_num_stages_by(self.configs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) \
|
||||
for config in self.configs}
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook != None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
module = inspect.getmodule(func)
|
||||
if module and module.__name__.startswith('triton.'):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if not hasattr(func, 'hash'):
|
||||
src = textwrap.dedent(inspect.getsource(func))
|
||||
tree = ast.parse(src)
|
||||
@@ -957,8 +1008,16 @@ class Config:
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
    """Render this config as a comma-separated ``name: value`` list.

    Every entry of ``self.kwargs`` comes first, in the dict's iteration
    order, followed by ``num_warps`` and ``num_stages``.

    :return: e.g. ``'BLOCK_M: 64, num_warps: 4, num_stages: 2'``.
    """
    parts = [f'{k}: {v}' for k, v in self.kwargs.items()]
    parts.append(f'num_warps: {self.num_warps}')
    parts.append(f'num_stages: {self.num_stages}')
    return ', '.join(parts)
|
||||
|
||||
def autotune(configs, key, reset_to_zero=None):
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
@@ -1023,7 +1086,6 @@ def heuristics(values):
|
||||
assert v not in meta
|
||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
||||
return kernel(*args, **meta)
|
||||
|
||||
return fun
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
|
Reference in New Issue
Block a user