[RUNTIME] Config hook v2.0 (#373)

* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
This commit is contained in:
daadaada
2021-11-22 03:20:59 +08:00
committed by GitHub
parent 5693b582ea
commit 1296eb877b
7 changed files with 44 additions and 40 deletions

View File

@@ -574,7 +574,7 @@ class Kernel:
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
gscope = self.fn.fn.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.fn.parse())
@@ -698,6 +698,7 @@ class Autotuner:
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -711,11 +712,14 @@ class Autotuner:
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
@@ -726,6 +730,8 @@ class Autotuner:
config = self.cache[key]
else:
config = self.configs[0]
if config.pre_hook != None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -893,11 +899,14 @@ class Config:
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2):
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def autotune(configs, key, reset_to_zero=None):
@@ -963,7 +972,7 @@ def heuristics(values):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur(*args, **meta)
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta)
return fun