[PYTHON][KERNEL] Added benchmarking functionalities for kernels

This commit is contained in:
Philippe Tillet
2019-10-27 15:32:34 -04:00
parent e11557855f
commit 0ec213547c
9 changed files with 207 additions and 112 deletions

View File

@@ -5,6 +5,7 @@ import shutil
import hashlib
import sysconfig
import sys
import weakref
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
@@ -176,6 +177,38 @@ def _make_grid(args) :
return grid
class bench_dict:
# Lazy entry for e.g., tensorflow, when value of benchmark is
# not known at graph compile time
class lazy_entry:
def __init__(self, id):
self.id = id
def get(self):
return libtriton.retrieve_scalar(self.id)
def __init__(self):
self.data = dict()
def __delitem__(self, key):
del self.data[id(key)]
def __getitem__(self, key):
ret = self.data[id(key)]
if isinstance(ret, bench_dict.lazy_entry):
return ret.get()
return ret
def __len__(self):
return len(self.data)
def __setitem__(self, key, value):
self.data[id(key)] = value
bench_registry = bench_dict()
class kernel:
def __init__(self, src, outputs):
@@ -200,7 +233,7 @@ class kernel:
defines.append((k, values))
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [4]
opt.num_warps = [2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
@@ -209,6 +242,10 @@ class kernel:
if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, opt)
# benchmarking info
bench = 0
if 'bench' in kwargs:
bench = kwargs['bench']
# retrieve framework op
op_id = self.fw_id[key]
# register grid
@@ -217,9 +254,16 @@ class kernel:
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
# call framework function
if fw.has_tensorflow():
return self.fw_op(*op_args, id=op_id)
bench_id = libtriton.make_scalar_id() if bench > 0 else 0
ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
if bench > 0:
bench_registry[ret] = bench_dict.lazy_entry(bench_id)
elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
return self.fw_op(op_id, *args)
ret = self.fw_op(op_id, bench, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
else:
assert False
assert False
return ret