[RUNTIME] Bunch of bugfixes (#372)

Author: Philippe Tillet
Date: 2021-11-12 00:55:00 -08:00 (committed by GitHub)
Parent: f7ab96cfd7
Commit: e66bf76354

5 changed files with 128 additions and 17 deletions

View File

@@ -782,11 +782,13 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
   ir::value* lhs = x->get_operand(0);
   ir::value* rhs = x->get_operand(1);
   int i = 0;
-  for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
+  for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){
     vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
+  }
   for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
     vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
   }
+  // std::cout << "!" << std::endl;
 }
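For context, `visit_cat_inst` lowers a `cat` by copying the distributed values of `lhs` and then `rhs` into the result, in index order; the fix above simply braces the first copy loop to match the second. A minimal Python sketch of the copy pattern (illustrative lists, not the generator's actual `vals_`/`idxs_` maps):

    def cat_values(lhs_vals, rhs_vals):
        # Result holds lhs values followed by rhs values, in index order,
        # mirroring the two loops in visit_cat_inst.
        out = []
        for v in lhs_vals:
            out.append(v)
        for v in rhs_vals:
            out.append(v)
        return out

    assert cat_values([1, 2], [3]) == [1, 2, 3]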

View File

@@ -0,0 +1,66 @@
+import torch
+import triton
+from triton.code_gen import JITFunction
+import triton.language as tl
+import os
+import shutil
+
+tmpdir = ".tmp"
+
+
+@triton.jit
+def function_1(i):
+    i = i + 1
+    i = function_2(i)
+    return i
+
+
+@triton.jit
+def function_2(i):
+    i = i + 1
+    return i
+
+
+@triton.jit
+def kernel(X, i, BLOCK: tl.constexpr):
+    i = i + 1
+    i = function_1(i)
+    tl.store(X, i)
+
+
+def apply_src_change(target, old, new):
+    delattr(kernel.fn, 'hash')
+    delattr(function_1.fn, 'hash')
+    delattr(function_2.fn, 'hash')
+    function_1.src = function_1.src.replace(old, new)
+    target.src = target.src.replace(old, new)
+    ret = target.cache_key
+    target.src = target.src.replace(new, old)
+    return ret
+
+
+def test_nochange():
+    baseline = kernel.cache_key
+    updated = apply_src_change(kernel, 'i + 1', 'i + 1')
+    assert baseline == updated
+
+
+def test_toplevel_change():
+    baseline = kernel.cache_key
+    updated = apply_src_change(kernel, 'i + 1', 'i + 2')
+    assert baseline != updated
+
+
+def test_nested1_change():
+    baseline = kernel.cache_key
+    updated = apply_src_change(function_1, 'i + 1', 'i + 2')
+    assert baseline != updated
+
+
+def test_reuse():
+    counter = 0
+
+    def inc_counter(key, binary):
+        nonlocal counter
+        counter += 1
+
+    os.environ["TRITON_CACHE_DIR"] = tmpdir
+    if os.path.exists(tmpdir):
+        shutil.rmtree(tmpdir)
+    JITFunction.cache_hook = inc_counter
+    x = torch.empty(1, dtype=torch.int32, device='cuda')
+    for i in range(10):
+        kernel[(1,)](x, 43, BLOCK=1024)
+    assert counter == 1
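These tests pin down the new caching behavior: the key is unchanged by a no-op source edit, changes when either the kernel or a nested `@triton.jit` callee changes, and each distinct key triggers exactly one compilation. The `cache_hook` pattern from `test_reuse` also works as a simple compilation logger; a minimal sketch, assuming the hook receives the cache key and compiled binary as in `test_reuse` above:

    # A minimal sketch: log every fresh compilation via the hook.
    def log_compile(key, binary):
        print(f"compiled new binary for cache key {key[:16]}...")

    JITFunction.cache_hook = log_compile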

View File

@@ -6,7 +6,6 @@ __version__ = '1.1.1'
 import torch
 # submodules
 from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
 from . import language
 from . import code_gen
 from . import testing

View File

@@ -412,6 +412,8 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_Call(self, node):
         fn = self.visit(node.func)
+        if isinstance(fn, triton.language.constexpr):
+            fn = fn.value
         kws = dict()
         for keyword in node.keywords:
             kws.update(self.visit(keyword))
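The two added lines handle callees that reach the visitor wrapped in `triton.language.constexpr` (e.g. a function forwarded through a `tl.constexpr` argument): the wrapper is peeled off so the dispatch below sees the underlying callable. A minimal sketch of the unwrapping, where the wrapper class is an illustrative stand-in:

    class Constexpr:
        """Illustrative stand-in for triton.language.constexpr."""
        def __init__(self, value):
            self.value = value

    fn = Constexpr(max)           # a callable wrapped as a constexpr
    if isinstance(fn, Constexpr):
        fn = fn.value             # unwrap before calling, as visit_Call now does
    assert fn(2, 3) == 3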
@@ -652,6 +654,9 @@ class Kernel:
             wargs[pos] = _type(wargs[pos])
         # query device index and cuda stream
         device = torch.cuda.current_device()
+        torch.cuda.set_device(device)
+        cc = torch.cuda.get_device_capability(device)
+        cc = str(cc[0]) + '-' + str(cc[1])
         # query stream
         # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
         # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
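`cc` renders the current device's compute capability as a 'major-minor' string that is appended to the cache key below, so binaries built for one architecture are never reused on another. A worked example on a CUDA machine (device value illustrative):

    import torch

    device = torch.cuda.current_device()
    cc = torch.cuda.get_device_capability(device)  # e.g. (8, 0) on an A100
    cc = str(cc[0]) + '-' + str(cc[1])             # -> '8-0'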
@@ -660,8 +665,9 @@ class Kernel:
         bits = torch._C._cuda_getCurrentStream(device)
         mask = 1 << 47
         stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
+        # stream = torch.cuda.current_stream(device).cuda_stream
         # make key for cache
-        return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream,
+        return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
                                       self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
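The `mask` trick above is a two's-complement sign extension: the launch path keeps only the low 48 bits of the packed stream handle, and `(x ^ m) - m` with `m = 1 << 47` re-extends bit 47 so negative handles round-trip correctly. A worked example with an assumed packed value:

    mask = 1 << 47
    bits = 0xFFFF_8000_0000_0042   # assumed: flag bits up top, handle below
    low48 = bits & 0xFFFFFFFFFFFF  # keep the low 48 bits
    stream = (low48 ^ mask) - mask # sign-extend bit 47
    assert stream == 0x8000_0000_0042 - (1 << 48)  # negative: bit 47 was set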
@@ -723,11 +729,6 @@ class Autotuner:
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
-
-@functools.lru_cache()
-def compute_capability():
-    device = torch.device('cuda', 0)
-    return '-'.join(map(str, torch.cuda.get_device_capability(device)))
 
 @functools.lru_cache()
 def version_key():
     import pkgutil
@@ -750,16 +751,49 @@ def version_key():
         ptxas_version = ''
     return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
+
+#########################
+
+class DependenciesFinder(ast.NodeVisitor):
+
+    def __init__(self, globals, src) -> None:
+        super().__init__()
+        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
+        self.globals = globals
+
+    def visit_Name(self, node):
+        return self.globals.get(node.id, None)
+
+    def visit_Attribute(self, node):
+        lhs = self.visit(node.value)
+        while isinstance(lhs, ast.Attribute):
+            lhs = self.visit(lhs.value)
+        if lhs is None or lhs is triton:
+            return None
+        return getattr(lhs, node.attr)
+
+    def visit_Call(self, node):
+        func = self.visit(node.func)
+        if func is None:
+            return
+        if isinstance(func, triton.JITFunction):
+            func = func.fn
+        module = inspect.getmodule(func)
+        if module and module.__name__.startswith('triton.'):
+            return
+        if not hasattr(func, 'hash'):
+            src = textwrap.dedent(inspect.getsource(func))
+            tree = ast.parse(src)
+            finder = DependenciesFinder(func.__globals__, src)
+            finder.visit(tree)
+            func.hash = finder.ret
+        self.ret = (self.ret + func.hash).encode("utf-8")
+        self.ret = hashlib.md5(self.ret).hexdigest()
+
+
 class JITFunction:
 
     cache_hook = None
-
-    def _set_cache_key(self):
-        self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest()
-        self.cache_key += str(self.version)
-        self.cache_key += version_key()
-        self.cache_key += compute_capability()
-        self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest()
 
     def __init__(self, fn, version=None, do_not_specialize=None):
         # information of wrapped function
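`DependenciesFinder` is what makes the new cache key transitive: a function's hash starts from the md5 of its own source, and each non-triton JIT callee folds its recursively computed hash back in, which is why `test_nested1_change` sees `kernel`'s key move when only `function_1` changes. A rough standalone equivalent of the composition (illustrative helper, not part of the diff):

    import hashlib

    def combined_hash(src, dep_hashes):
        # md5 of the function's own source...
        ret = hashlib.md5(src.encode("utf-8")).hexdigest()
        # ...then fold each dependency hash in, the way visit_Call updates self.ret.
        for dep in dep_hashes:
            ret = hashlib.md5((ret + dep).encode("utf-8")).hexdigest()
        return ret

    h2 = combined_hash("def function_2(i): return i + 1", [])
    h1 = combined_hash("def function_1(i): return function_2(i)", [h2])
    h2_edited = combined_hash("def function_2(i): return i + 2", [])
    assert h1 != combined_hash("def function_1(i): return function_2(i)", [h2_edited])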
@@ -772,8 +806,6 @@ class JITFunction:
             [self.arg_names.index(arg) for arg in do_not_specialize]
         # cache for callable driver objects (e.g. CUkernel)
         self.bin_cache = dict()
-        # cache for binaries (on-disk)
-        self._set_cache_key()
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel_decorators = []
@@ -785,6 +817,15 @@ class JITFunction:
         self.__doc__ = fn.__doc__
 
+    @property
+    @functools.lru_cache()
+    def cache_key(self):
+        if not hasattr(self.fn, 'hash'):
+            dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
+            dependencies_finder.visit(self.parse())
+            self.fn.hash = dependencies_finder.ret
+        return self.fn.hash
+
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Some unit tests do this, for example.
@@ -821,7 +862,9 @@ class JITFunction:
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
         if name == 'src':
-            self._set_cache_key()
+            if hasattr(self.fn, 'hash'):
+                delattr(self.fn, 'hash')
+            JITFunction.cache_key.fget.cache_clear()
 
     def _init_kernel(self):
         if self.kernel is None:
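Because `cache_key` is now an `lru_cache`d property, repeated launches skip the AST walk entirely; the `__setattr__` hook keeps monkey-patching `src` safe by dropping the stored `fn.hash` and clearing the property's memo in one step. A usage sketch reusing the `kernel` defined in the new unit tests:

    # A minimal sketch: editing src invalidates both caching layers.
    old_key = kernel.cache_key                         # computed and memoized
    kernel.src = kernel.src.replace('i + 1', 'i + 2')  # triggers __setattr__
    new_key = kernel.cache_key                         # recomputed from new src
    assert old_key != new_key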

View File

@@ -110,6 +110,7 @@ def randint4x(seed, offset):
     :param offsets: The offsets to generate random numbers for.
     """
     z = offset*0  # FIXME: just 0 doesn't work. Likely some error with broadcasting
+    seed = seed + 0
     seed = hacky_to_uint64(seed)  # uint will solve this
     seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
     seed_lo = (seed & 0xffffffff).to(tl.int32)
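`seed = seed + 0` mirrors the `offset*0` trick above: both presumably force a plain Python argument through an arithmetic op so it becomes a proper Triton value before `hacky_to_uint64`. For context, a sketch of how this helper is typically reached through the public API (illustrative kernel; `tl.rand` wraps `randint4x` internally):

    import triton
    import triton.language as tl

    @triton.jit
    def _rand_kernel(out_ptr, seed, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        r = tl.rand(seed, offs)       # Philox-based, lowers to randint4x
        tl.store(out_ptr + offs, r)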