[RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
This commit is contained in:
@@ -574,7 +574,7 @@ class Kernel:
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self.fn into code-generator object
|
||||
gscope = sys.modules[self.fn.module].__dict__
|
||||
gscope = self.fn.fn.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
try:
|
||||
generator.visit(self.fn.parse())
|
||||
@@ -698,6 +698,7 @@ class Autotuner:
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
@@ -711,11 +712,14 @@ class Autotuner:
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return triton.testing.do_bench(kernel_call)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
@@ -726,6 +730,8 @@ class Autotuner:
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
if config.pre_hook != None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
|
||||
@@ -893,11 +899,14 @@ class Config:
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
||||
function are args.
|
||||
"""
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2):
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
|
||||
def autotune(configs, key, reset_to_zero=None):
|
||||
@@ -963,7 +972,7 @@ def heuristics(values):
|
||||
def fun(*args, **meta):
|
||||
for v, heur in values.items():
|
||||
assert v not in meta
|
||||
meta[v] = heur(*args, **meta)
|
||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
||||
return kernel(*args, **meta)
|
||||
|
||||
return fun
|
||||
|
Reference in New Issue
Block a user