[FRONTEND] Significantly reduce kernel launch time (#367)
@@ -464,6 +464,7 @@ class LoadedBinary:
        self.module = module
        self.kernel = kernel
        self.device = device
        self.shared_mem = bin.shared_mem

    def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
@@ -548,16 +549,6 @@ class Kernel:
        name = Kernel._type_name(obj)
        return type_map[name](context)

    @staticmethod
    def _types_key(*wargs, tensor_idxs):
        # type inference
        types_key = [None] * len(wargs)
        for i, arg in enumerate(wargs):
            prefix = 'P' if i in tensor_idxs else ''
            suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
            types_key[i] = prefix + suffix
        return tuple(types_key)

    @staticmethod
    def pow2_divisor(N):
        if N % 16 == 0: return 16
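The hunk above cuts `pow2_divisor` off at its first branch. A helper like this returns the largest power of two (capped at 16) that divides `N`, which is what lets integer arguments such as sizes and strides be specialized on their alignment. A minimal sketch of that behaviour (an illustration, not necessarily the exact body of the original function):

```python
def pow2_divisor(N):
    # largest power of two (up to 16) that divides N; used to specialize
    # integer kernel arguments on their alignment
    for d in (16, 8, 4, 2):
        if N % d == 0:
            return d
    return 1

assert pow2_divisor(48) == 16
assert pow2_divisor(20) == 4
assert pow2_divisor(7) == 1
```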
@@ -599,6 +590,53 @@ class Kernel:
            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
        return Binary(backend, name, asm, shared_mem, num_warps)

    def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        # attributes
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
                      if isinstance(a, int) and i not in self.fn.do_not_specialize}

        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()

        # create cache directory
        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
        if cache_dir and not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)

        if cache_dir:
            bin_cache_path = os.path.join(cache_dir, hashed_key)
            bin_lock_path = bin_cache_path + ".lock"
        else:
            bin_cache_path = None
            bin_lock_path = None

        binary = None
        if bin_cache_path and os.path.exists(bin_cache_path):
            assert bin_lock_path is not None
            with FileLock(bin_lock_path):
                with open(bin_cache_path, 'rb') as f:
                    binary = pickle.load(f)["binary"]
        if binary is None:
            binary = self._compile(
                *wargs, device=device_idx, attributes=attributes,
                num_warps=num_warps, num_stages=num_stages,
                constants=constants,
            )
            if bin_cache_path:
                assert bin_lock_path is not None
                with FileLock(bin_lock_path):
                    with open(bin_cache_path + ".tmp", "wb") as f:
                        pickle.dump({"binary": binary, "key": key}, f)
                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
            if JITFunction.cache_hook is not None:
                JITFunction.cache_hook(key=key, binary=binary)

        self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)

    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
        # handle arguments passed by name
        kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
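`add_to_cache` is the slow path that the new launcher invokes only on a cache miss: it recomputes the specialization attributes and constants, compiles if necessary, and persists the binary under `TRITON_CACHE_DIR` behind a file lock, publishing it with a write-to-temp-then-rename so a concurrent process never reads a half-written file. A stand-alone sketch of that persistence pattern, assuming `FileLock` comes from the `filelock` package and using invented helper names:

```python
import os
import pickle
from filelock import FileLock  # assumption: the same FileLock used in the diff

def persist_binary(cache_dir, hashed_key, binary, key):
    # write under a lock, then atomically rename so readers see either the
    # old file or the complete new one, never a partial write
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, hashed_key)
    with FileLock(path + ".lock"):
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump({"binary": binary, "key": key}, f)
        os.rename(tmp, path)  # atomic within one filesystem on POSIX
```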
@@ -608,112 +646,21 @@ class Kernel:
        if len(wargs) != len(self.fn.arg_names):
            raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
        # handle annotations
        for name, type in self.fn.__annotations__.items():
            pos = self.fn.arg_names.index(name)
            assert type == triton.language.core.constexpr
            wargs[pos] = type(wargs[pos])
        # device inference
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        if len(tensor_idxs) == 0:
            raise ValueError("No Tensor argument found.")
        invalid_args = []
        device_ids = []
        for idx in tensor_idxs:
            curr = wargs[idx]
            if not curr.is_cuda:
                invalid_args.append(idx)
            else:
                device_ids.append(curr.device.index)
        if invalid_args:
            raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
                             " Only CUDA is supported at the moment")

        device = torch.device('cuda', torch.cuda.current_device())
        device_idx = device.index
        # if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
        #     # try to enable P2P communication
        #     for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
        #         if dst_idx != device_idx:
        #             try:
        #                 _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
        #             except RuntimeError as e:
        #                 raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
        #                                    .format(device_idx, dst_idx, str(e)))

        # enqueue kernel on the current device
        torch.cuda.set_device(device_idx)
        # attributes
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
                      if isinstance(a, int) and i not in self.fn.do_not_specialize}

        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})

        # compute hash for caching this kernel
        types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
        attr_key = tuple(attributes.items())
        const_key = tuple(constants.items())
        compute_capability = torch.cuda.get_device_capability(device)
        key = (
            self.fn.cache_key, version_key(), compute_capability,
            types_key, attr_key, num_warps, num_stages, const_key
        )
        key = repr(key)

        # get cached binary
        drv_cache = self.fn.drv_cache

        if key not in drv_cache:
            hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()

            # create cache directory
            cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
            if cache_dir and not os.path.exists(cache_dir):
                os.makedirs(cache_dir, exist_ok=True)

            if cache_dir:
                bin_cache_path = os.path.join(cache_dir, hashed_key)
                bin_lock_path = bin_cache_path + ".lock"
            else:
                bin_cache_path = None
                bin_lock_path = None

            binary = None
            if bin_cache_path and os.path.exists(bin_cache_path):
                assert bin_lock_path is not None
                with FileLock(bin_lock_path):
                    with open(bin_cache_path, 'rb') as f:
                        binary = pickle.load(f)["binary"]
            if binary is None:
                binary = self._compile(
                    *wargs, device=device_idx, attributes=attributes,
                    num_warps=num_warps, num_stages=num_stages,
                    constants=constants,
                )
                if bin_cache_path:
                    assert bin_lock_path is not None
                    with FileLock(bin_lock_path):
                        with open(bin_cache_path + ".tmp", "wb") as f:
                            pickle.dump({"binary": binary, "key": key}, f)
                        os.rename(bin_cache_path + ".tmp", bin_cache_path)
                if JITFunction.cache_hook is not None:
                    JITFunction.cache_hook(key=key, binary=binary)

            drv_cache[key] = LoadedBinary(device_idx, binary)
        # pack arguments
        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
        # enqueue cached function into stream
        callable = drv_cache[key]
        stream = torch.cuda.current_stream(device_idx).cuda_stream
        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
        grid = grid(csts) if hasattr(grid, '__call__') else grid
        if isinstance(grid, int):
            grid = tuple(grid)
        callable(stream, params, *grid)
        return callable
        for pos, _type in self.fn.annotations.items():
            wargs[pos] = _type(wargs[pos])
        # query device index and cuda stream
        device = torch.cuda.current_device()
        # query stream
        # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
        # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
        # building a C wrapper to re-use the unpack function would add a build-time torch dependency
        # and require different wheels for different torch versions -- undesirable!
        bits = torch._C._cuda_getCurrentStream(device)
        mask = 1 << 47
        stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
        # make key for cache
        return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream,
                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)


class Launcher:
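The comment block in the new `__call__` above is the core of the latency fix on the Python side: instead of building a stream wrapper through `torch.cuda.current_stream(device).cuda_stream`, the code reads the packed stream representation from `torch._C._cuda_getCurrentStream` and sign-extends its low 48 bits by hand, then hands everything, including `self.add_to_cache` as the miss callback, to `_triton.runtime.launch` so the hot path stays in C++. A small worked example of the two's-complement trick, with made-up bit patterns:

```python
def sign_extend_48(bits):
    # keep the low 48 bits and reinterpret them as a signed 48-bit integer,
    # mirroring `((bits & 0xFFFFFFFFFFFF) ^ mask) - mask` in the diff
    mask = 1 << 47
    return ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask

assert sign_extend_48(0x0000_0000_002A) == 42      # small positive value is unchanged
assert sign_extend_48(0xFFFF_FFFF_FFFF) == -1      # bit 47 set, so the result is negative
assert sign_extend_48(0x1234_0000_0000_0005) == 5  # bits above 47 are discarded
```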
@@ -723,6 +670,7 @@ class Launcher:

    def __call__(self, *wargs, **kwargs):
        return self.kernel(*wargs, **kwargs, grid=self.grid)



class Autotuner:
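For context, `Launcher` is the small object behind the usual `kernel[grid](...)` syntax: indexing a `@triton.jit` function with a grid returns a `Launcher`, and calling it forwards to `Kernel.__call__` above with the grid bound. A hedged usage sketch; the kernel body and sizes are invented, roughly matching the Triton API of this era:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
# indexing with a grid callable yields a Launcher; the call goes through Kernel.__call__
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
```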
@@ -773,6 +721,11 @@ class Autotuner:
        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)


@functools.lru_cache()
def compute_capability():
    device = torch.device('cuda', 0)
    return '-'.join(map(str, torch.cuda.get_device_capability(device)))

@functools.lru_cache()
def version_key():
    import pkgutil
@@ -784,22 +737,27 @@ def version_key():
    with open(triton._C.libtriton.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # language
    for lib in pkgutil.iter_modules(triton.language.__path__):
    language_path = os.path.join(*triton.__path__, 'language')
    for lib in pkgutil.iter_modules([language_path]):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = None
    return (triton.__version__, ptxas_version) + tuple(contents)
        ptxas_version = ''
    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)


class JITFunction:

    cache_hook = None

    def _set_cache_key(self):
        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
        self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest()
        self.cache_key += str(self.version)
        self.cache_key += version_key()
        self.cache_key += compute_capability()
        self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest()

    def __init__(self, fn, version=None, do_not_specialize=None):
        # information of wrapped function
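Taken together, `compute_capability()`, `version_key()` and the rewritten `_set_cache_key()` fold everything that can invalidate a compiled binary (kernel source, user-supplied version, the installed Triton bits, the ptxas version, and the GPU architecture) into one md5 digest computed once per `JITFunction` rather than on every launch. A reduced sketch of the same idea, with invented helper names standing in for the real fingerprinting:

```python
import functools
import hashlib

@functools.lru_cache()                 # evaluated once per process, like version_key()
def environment_fingerprint():
    # stand-in for hashing libtriton, the triton.language sources and `ptxas --version`
    return hashlib.md5(b"libtriton + language sources + ptxas --version").hexdigest()

def make_cache_key(src, version, capability):
    # chain the pieces, then hash once, mirroring _set_cache_key in the diff
    key = hashlib.md5(src.encode("utf-8")).hexdigest()
    key += str(version) + environment_fingerprint() + capability
    return hashlib.md5(key.encode("utf-8")).hexdigest()

key = make_cache_key("def kernel(...): ...", version=None, capability="8-0")
```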
@@ -811,7 +769,7 @@ class JITFunction:
        self.do_not_specialize = [] if do_not_specialize is None else\
            [self.arg_names.index(arg) for arg in do_not_specialize]
        # cache for callable driver objects (e.g. CUkernel)
        self.drv_cache = dict()
        self.bin_cache = dict()
        # cache for binaries (on-disk)
        self._set_cache_key()
        # JITFunction can be instantiated as kernel
@@ -819,6 +777,7 @@ class JITFunction:
        self.kernel_decorators = []
        self.kernel = None
        # annotations
        self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
        self.__annotations__ = fn.__annotations__
        # forward docs
        self.__doc__ = fn.__doc__
@@ -834,7 +793,7 @@ class JITFunction:
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, generator: CodeGenerator, **meta):
    def __call__(self, *args, generator: CodeGenerator):
        try:
            gscope = generator.gscope.copy()
            lscope = generator.lscope.copy()