[PYTHON][KERNEL] Added benchmarking functionalities for kernels
This commit is contained in:
@@ -5,6 +5,7 @@ import shutil
|
||||
import hashlib
|
||||
import sysconfig
|
||||
import sys
|
||||
import weakref
|
||||
# import for just-in-time compilation
|
||||
import distutils
|
||||
import setuptools.command.build_ext
|
||||
@@ -176,6 +177,38 @@ def _make_grid(args) :
|
||||
return grid
|
||||
|
||||
|
||||
class bench_dict:
|
||||
|
||||
# Lazy entry for e.g., tensorflow, when value of benchmark is
|
||||
# not known at graph compile time
|
||||
class lazy_entry:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def get(self):
|
||||
return libtriton.retrieve_scalar(self.id)
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.data = dict()
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.data[id(key)]
|
||||
|
||||
def __getitem__(self, key):
|
||||
ret = self.data[id(key)]
|
||||
if isinstance(ret, bench_dict.lazy_entry):
|
||||
return ret.get()
|
||||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[id(key)] = value
|
||||
|
||||
bench_registry = bench_dict()
|
||||
|
||||
class kernel:
|
||||
|
||||
def __init__(self, src, outputs):
|
||||
@@ -200,7 +233,7 @@ class kernel:
|
||||
defines.append((k, values))
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [4]
|
||||
opt.num_warps = [2, 4, 8]
|
||||
# create unique id for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
@@ -209,6 +242,10 @@ class kernel:
|
||||
if self.fw_op is None:
|
||||
self.fw_op = _make_framework_op(self.src, self.outputs, opt)
|
||||
|
||||
# benchmarking info
|
||||
bench = 0
|
||||
if 'bench' in kwargs:
|
||||
bench = kwargs['bench']
|
||||
# retrieve framework op
|
||||
op_id = self.fw_id[key]
|
||||
# register grid
|
||||
@@ -217,9 +254,16 @@ class kernel:
|
||||
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
|
||||
# call framework function
|
||||
if fw.has_tensorflow():
|
||||
return self.fw_op(*op_args, id=op_id)
|
||||
bench_id = libtriton.make_scalar_id() if bench > 0 else 0
|
||||
ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
|
||||
if bench > 0:
|
||||
bench_registry[ret] = bench_dict.lazy_entry(bench_id)
|
||||
|
||||
elif fw.has_torch():
|
||||
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
|
||||
return self.fw_op(op_id, *args)
|
||||
ret = self.fw_op(op_id, bench, *args)
|
||||
if bench > 0:
|
||||
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
|
||||
else:
|
||||
assert False
|
||||
assert False
|
||||
return ret
|
Reference in New Issue
Block a user