[RUNTIME] Bunch of bugfixes (#372)
This commit is contained in:
@@ -782,11 +782,13 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
int i = 0;
|
||||
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
|
||||
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){
|
||||
vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
|
||||
}
|
||||
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
|
||||
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
|
||||
}
|
||||
// std::cout << "!" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
66
python/test/unit/runtime/test_cache.py
Normal file
66
python/test/unit/runtime/test_cache.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import triton
|
||||
from triton.code_gen import JITFunction
|
||||
import triton.language as tl
|
||||
import os
|
||||
import shutil
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
# First-level helper used by `kernel` in the cache-key tests.
# Adds 1 to `i`, forwards the result to `function_2`, and returns what
# `function_2` gives back. The literal text `i + 1` in the body is what the
# tests substitute via `apply_src_change`, so keep the statements verbatim.
@triton.jit
def function_1(i):
    i = i + 1
    i = function_2(i)
    return i
|
||||
|
||||
|
||||
# Second-level helper, called from `function_1`.
# Adds 1 to `i` and returns it; it calls nothing else.
@triton.jit
def function_2(i):
    i = i + 1
    return i
|
||||
|
||||
# Top-level kernel exercised by the cache-key tests.
# Adds 1 to `i`, passes it through `function_1`, and stores the result at `X`.
# `BLOCK` is a compile-time constant that the body does not read.
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)
|
||||
|
||||
|
||||
def apply_src_change(target, old, new):
    """Return ``target.cache_key`` computed with ``old`` replaced by ``new``.

    The substitution is applied to ``function_1.src`` and to ``target.src``,
    the cache key of ``target`` is read, and both sources are then put back
    exactly as they were, even if computing the key raises.

    :param target: JIT function whose source is edited and whose key is returned.
    :param old: substring to look for in the sources.
    :param new: replacement text.
    :return: the cache key of ``target`` while the edit is in place.
    """
    # Forget memoized hashes so that keys are recomputed from scratch.
    # A hash may legitimately be absent (never computed, or already dropped
    # by a previous `src` assignment), so only delete what is there.
    for jit_fn in (kernel, function_1, function_2):
        if hasattr(jit_fn.fn, 'hash'):
            delattr(jit_fn.fn, 'hash')
    # Snapshot the pristine sources; restoring from the snapshot (rather than
    # replacing `new` back with `old`) cannot clobber unrelated text and also
    # undoes the edit of `function_1` when `target` is a different function.
    function_1_src = function_1.src
    target_src = target.src
    try:
        function_1.src = function_1_src.replace(old, new)
        target.src = target.src.replace(old, new)
        return target.cache_key
    finally:
        target.src = target_src
        function_1.src = function_1_src
|
||||
|
||||
def test_nochange():
    """A substitution that rewrites nothing must leave the kernel key as is."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 1')
    assert key_before == key_after
|
||||
|
||||
def test_toplevel_change():
    """Editing the kernel's own source must yield a different cache key."""
    key_before = kernel.cache_key
    key_after = apply_src_change(kernel, 'i + 1', 'i + 2')
    assert key_before != key_after
|
||||
|
||||
def test_nested1_change():
    """Editing ``function_1`` must give a key different from the kernel baseline.

    NOTE(review): the baseline is the key of ``kernel`` while the second value
    is the key that ``apply_src_change`` returns for ``function_1`` itself, so
    two different functions are being compared -- confirm this is intended.
    """
    key_before = kernel.cache_key
    key_after = apply_src_change(function_1, 'i + 1', 'i + 2')
    assert key_before != key_after
|
||||
|
||||
def test_reuse():
    """Launch the same kernel ten times and check the cache hook fires once.

    The test points ``TRITON_CACHE_DIR`` at an empty scratch directory and
    installs a counting ``JITFunction.cache_hook``. Both global settings are
    restored afterwards and the scratch directory is removed, so other tests
    in the same process are not affected.
    """
    counter = 0

    def inc_counter(key, binary):
        # presumably invoked by the runtime when a binary is added to the
        # cache -- the arguments are not inspected here
        nonlocal counter
        counter += 1

    previous_cache_dir = os.environ.get("TRITON_CACHE_DIR")
    previous_hook = JITFunction.cache_hook
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    # start from an empty on-disk cache
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir)
    JITFunction.cache_hook = inc_counter
    try:
        x = torch.empty(1, dtype=torch.int32, device='cuda')
        for _ in range(10):
            kernel[(1,)](x, 43, BLOCK=1024)
    finally:
        # undo the process-wide changes made above
        JITFunction.cache_hook = previous_hook
        if previous_cache_dir is None:
            os.environ.pop("TRITON_CACHE_DIR", None)
        else:
            os.environ["TRITON_CACHE_DIR"] = previous_cache_dir
        shutil.rmtree(tmpdir, ignore_errors=True)
    assert counter == 1
|
@@ -6,7 +6,6 @@ __version__ = '1.1.1'
|
||||
import torch
|
||||
# submodules
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
|
||||
|
||||
from . import language
|
||||
from . import code_gen
|
||||
from . import testing
|
||||
|
@@ -412,6 +412,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, triton.language.constexpr):
|
||||
fn = fn.value
|
||||
kws = dict()
|
||||
for keyword in node.keywords:
|
||||
kws.update(self.visit(keyword))
|
||||
@@ -652,6 +654,9 @@ class Kernel:
|
||||
wargs[pos] = _type(wargs[pos])
|
||||
# query device index and cuda stream
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
cc = torch.cuda.get_device_capability(device)
|
||||
cc = str(cc[0]) + '-' + str(cc[1])
|
||||
# query stream
|
||||
# this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
|
||||
# https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
|
||||
@@ -660,8 +665,9 @@ class Kernel:
|
||||
bits = torch._C._cuda_getCurrentStream(device)
|
||||
mask = 1 << 47
|
||||
stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
|
||||
# stream = torch.cuda.current_stream(device).cuda_stream
|
||||
# make key for cache
|
||||
return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream,
|
||||
return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
|
||||
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
|
||||
|
||||
|
||||
@@ -723,11 +729,6 @@ class Autotuner:
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
|
||||
@functools.lru_cache()
def compute_capability():
    """Return the capability of CUDA device 0 as a dash-joined string.

    The values reported by ``torch.cuda.get_device_capability`` are converted
    to strings and joined with ``'-'``. The result is memoized, so the device
    is only queried on the first call.
    """
    first_gpu = torch.device('cuda', 0)
    capability = torch.cuda.get_device_capability(first_gpu)
    return '-'.join(str(part) for part in capability)
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
@@ -750,16 +751,49 @@ def version_key():
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
#########################
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
    """AST visitor that folds the source of called functions into one digest.

    ``ret`` starts as the MD5 hex digest of ``src``. Every call whose callee
    can be resolved through ``globals`` (and does not live in a ``triton.*``
    module) has its own digest mixed into ``ret``; that digest is computed
    recursively with another finder and memoized on the callee as ``hash``.
    """

    def __init__(self, globals, src) -> None:
        """Seed the digest with ``src`` and remember the lookup namespace.

        :param globals: mapping used to resolve names found in the AST.
        :param src: source text of the function being analysed.
        """
        super().__init__()
        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
        self.globals = globals

    def visit_Name(self, node):
        """Resolve a bare name through ``globals``; ``None`` if it is unknown."""
        return self.globals.get(node.id, None)

    def visit_Attribute(self, node):
        """Resolve ``value.attr`` to an object, or ``None`` if it cannot be.

        Attributes hanging directly off the ``triton`` package are treated as
        unresolvable, as are attributes the resolved object does not have.
        """
        lhs = self.visit(node.value)
        if lhs is None or lhs is triton:
            return None
        # A missing attribute only means we cannot follow this expression;
        # it must not abort hashing of the whole function.
        return getattr(lhs, node.attr, None)

    def visit_Call(self, node):
        """Mix the digest of the called function, if resolvable, into ``ret``."""
        # Defining visit_Call stops the default traversal at this node, so the
        # arguments have to be walked explicitly; otherwise calls nested in
        # arguments, e.g. ``f(g(x))``, would never contribute to the digest.
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        func = self.visit(node.func)
        if func is None:
            return
        if isinstance(func, triton.JITFunction):
            func = func.fn
        module = inspect.getmodule(func)
        if module and module.__name__.startswith('triton.'):
            return
        if not hasattr(func, 'hash'):
            # first time this callee is seen: hash it together with its own
            # dependencies, then memoize the result on the function object
            src = textwrap.dedent(inspect.getsource(func))
            tree = ast.parse(src)
            finder = DependenciesFinder(func.__globals__, src)
            finder.visit(tree)
            func.hash = finder.ret
        combined = (self.ret + func.hash).encode("utf-8")
        self.ret = hashlib.md5(combined).hexdigest()
|
||||
|
||||
class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
def _set_cache_key(self):
    # Build `self.cache_key` in steps: start from the MD5 hex digest of the
    # source, append the function version, the value of `version_key()` and
    # the value of `compute_capability()`, then hash the concatenation again
    # so the final key is a single MD5 hex digest.
    self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest()
    self.cache_key += str(self.version)
    self.cache_key += version_key()
    self.cache_key += compute_capability()
    self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest()
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
@@ -772,8 +806,6 @@ class JITFunction:
|
||||
[self.arg_names.index(arg) for arg in do_not_specialize]
|
||||
# cache for callable driver objects (e.g. CUkernel)
|
||||
self.bin_cache = dict()
|
||||
# cache for binaries (on-disk)
|
||||
self._set_cache_key()
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
@@ -785,6 +817,15 @@ class JITFunction:
|
||||
self.__doc__ = fn.__doc__
|
||||
|
||||
|
||||
@property
@functools.lru_cache()
def cache_key(self):
    """Digest of this function's source combined with its dependencies.

    The digest is memoized on the wrapped function as ``fn.hash``; it is
    computed with a ``DependenciesFinder`` only when that attribute is absent.
    """
    wrapped = self.fn
    if hasattr(wrapped, 'hash'):
        return wrapped.hash
    finder = DependenciesFinder(globals=wrapped.__globals__, src=self.src)
    finder.visit(self.parse())
    wrapped.hash = finder.ret
    return wrapped.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
# the user might want to monkey-patch self.src dynamically.
|
||||
# Some unit tests do this, for example.
|
||||
@@ -821,7 +862,9 @@ class JITFunction:
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
if name == 'src':
|
||||
self._set_cache_key()
|
||||
if hasattr(self.fn, 'hash'):
|
||||
delattr(self.fn, 'hash')
|
||||
JITFunction.cache_key.fget.cache_clear()
|
||||
|
||||
def _init_kernel(self):
|
||||
if self.kernel is None:
|
||||
|
@@ -110,6 +110,7 @@ def randint4x(seed, offset):
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
z = offset*0 #FIXME: just 0 doesn't work. Likely some error with broadcasting
|
||||
seed = seed + 0
|
||||
seed = hacky_to_uint64(seed) # uint will solve this
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||
|
Reference in New Issue
Block a user