From e66bf76354a63d1c0d49a2534b51031947ee5251 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 12 Nov 2021 00:55:00 -0800 Subject: [PATCH] [RUNTIME] Bunch of bugfixes (#372) --- lib/codegen/selection/generator.cc | 4 +- python/test/unit/runtime/test_cache.py | 66 +++++++++++++++++++++++ python/triton/__init__.py | 1 - python/triton/code_gen.py | 73 ++++++++++++++++++++------ python/triton/language/random.py | 1 + 5 files changed, 128 insertions(+), 17 deletions(-) create mode 100644 python/test/unit/runtime/test_cache.py diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 7316e047a..eeabb6841 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -782,11 +782,13 @@ void generator::visit_cat_inst(ir::cat_inst* x) { ir::value* lhs = x->get_operand(0); ir::value* rhs = x->get_operand(1); int i = 0; - for(size_t j = 0; j < idxs_.at(lhs).size(); j ++) + for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; + } for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; } +// std::cout << "!" << std::endl; } diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..215b90d8b --- /dev/null +++ b/python/test/unit/runtime/test_cache.py @@ -0,0 +1,66 @@ +import torch +import triton +from triton.code_gen import JITFunction +import triton.language as tl +import os +import shutil + +tmpdir = ".tmp" + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +def apply_src_change(target, old, new): + delattr(kernel.fn, 'hash') + delattr(function_1.fn, 'hash') + delattr(function_2.fn, 'hash') + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + +def test_reuse(): + counter = 0 + def inc_counter(key, binary): + nonlocal counter + counter += 1 + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device='cuda') + for i in range(10): + kernel[(1,)](x, 43, BLOCK=1024) + assert counter == 1 diff --git a/python/triton/__init__.py b/python/triton/__init__.py index a76df9b75..4b8c54703 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,7 +6,6 @@ __version__ = '1.1.1' import torch # submodules from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret - from . import language from . import code_gen from . import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 22b910f5a..d418cb9d1 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -412,6 +412,8 @@ class CodeGenerator(ast.NodeVisitor): def visit_Call(self, node): fn = self.visit(node.func) + if isinstance(fn, triton.language.constexpr): + fn = fn.value kws = dict() for keyword in node.keywords: kws.update(self.visit(keyword)) @@ -652,6 +654,9 @@ class Kernel: wargs[pos] = _type(wargs[pos]) # query device index and cuda stream device = torch.cuda.current_device() + torch.cuda.set_device(device) + cc = torch.cuda.get_device_capability(device) + cc = str(cc[0]) + '-' + str(cc[1]) # query stream # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 @@ -660,8 +665,9 @@ class Kernel: bits = torch._C._cuda_getCurrentStream(device) mask = 1 << 47 stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask + # stream = torch.cuda.current_stream(device).cuda_stream # make key for cache - return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream, + return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) @@ -723,11 +729,6 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) -@functools.lru_cache() -def compute_capability(): - device = torch.device('cuda', 0) - return '-'.join(map(str, torch.cuda.get_device_capability(device))) - @functools.lru_cache() def version_key(): import pkgutil @@ -750,16 +751,49 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) +#########################3 + + +class DependenciesFinder(ast.NodeVisitor): + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if isinstance(func, triton.JITFunction): + func = func.fn + module = inspect.getmodule(func) + if module and module.__name__.startswith('triton.'): + return + if not hasattr(func, 'hash'): + src = textwrap.dedent(inspect.getsource(func)) + tree = ast.parse(src) + finder = DependenciesFinder(func.__globals__, src) + finder.visit(tree) + func.hash = finder.ret + self.ret = (self.ret + func.hash).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + class JITFunction: cache_hook = None - def _set_cache_key(self): - self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest() - self.cache_key += str(self.version) - self.cache_key += version_key() - self.cache_key += compute_capability() - self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest() def __init__(self, fn, version=None, do_not_specialize=None): # information of wrapped function @@ -772,8 +806,6 @@ class JITFunction: [self.arg_names.index(arg) for arg in do_not_specialize] # cache for callable driver objects (e.g. CUkernel) self.bin_cache = dict() - # cache for binaries (on-disk) - self._set_cache_key() # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel_decorators = [] @@ -785,6 +817,15 @@ class JITFunction: self.__doc__ = fn.__doc__ + @property + @functools.lru_cache() + def cache_key(self): + if not hasattr(self.fn, 'hash'): + dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.fn.hash = dependencies_finder.ret + return self.fn.hash + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Some unit tests do this, for example. @@ -821,7 +862,9 @@ class JITFunction: self.kernel = None super(JITFunction, self).__setattr__(name, value) if name == 'src': - self._set_cache_key() + if hasattr(self.fn, 'hash'): + delattr(self.fn, 'hash') + JITFunction.cache_key.fget.cache_clear() def _init_kernel(self): if self.kernel is None: diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 3a3d7f9e1..9bb29588a 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -110,6 +110,7 @@ def randint4x(seed, offset): :param offsets: The offsets to generate random numbers for. """ z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting + seed = seed + 0 seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.int32)