[FRONTEND] improved caching mechanism (#474)

Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
Philippe Tillet authored on 2022-03-15 12:20:51 -07:00, committed by GitHub
parent 21f8a0646d
commit d4d8eaf6c0
3 changed files with 106 additions and 84 deletions

View File

@@ -299,8 +299,12 @@ void init_triton_runtime(py::module &&m) {
     // get cached binary
     py::str key(cache_key);
-    if(!bin_cache.contains(key))
-      add_to_cache(key, args, device, num_warps, num_stages);
+    py::bool_ noop = false;
+    if(!bin_cache.contains(key)) {
+      noop = add_to_cache(key, args, device, num_warps, num_stages);
+    }
+    if (noop)
+      return (py::object)py::none();
     py::object bin = bin_cache[key];
     // get grid
@@ -529,6 +533,7 @@ void init_triton_codegen(py::module &&m) {
       return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
     }, py::return_value_policy::take_ownership);
   m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+    py::gil_scoped_release allow_threads;
     if(backend == CUDA)
       return cu_load_binary(name, asm_map, n_shared_bytes, dev);
     if(backend == ROCM)

View File

@@ -76,7 +76,7 @@ def reset_tmp_dir():
 def test_reuse():
     counter = 0
-    def inc_counter(key, binary, repr):
+    def inc_counter(*args, **kwargs):
         nonlocal counter
         counter += 1
     JITFunction.cache_hook = inc_counter
@@ -91,7 +91,7 @@ def test_reuse():
 def test_specialize(mode):
     counter = 0
-    def inc_counter(key, binary, repr):
+    def inc_counter(*args, **kwargs):
         nonlocal counter
         counter += 1
     JITFunction.cache_hook = inc_counter
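
The hook's call convention changes in this commit: `JITFunction._warmup` (see the last file) now calls it with keyword arguments only (`key`, `repr`, `fn`, `compile`, `is_manual_warmup`, `already_compiled`) and treats a truthy return value as a request to skip compilation, which is why the tests switch `inc_counter` to `*args, **kwargs`. A minimal sketch of a hook written against the new interface, assuming `JITFunction` is importable from `triton.code_gen` as in this revision of the tree:

    from triton.code_gen import JITFunction

    seen = []

    def log_hook(*args, **kwargs):
        # keyword arguments supplied by _warmup in this commit:
        # key, repr, fn, compile, is_manual_warmup, already_compiled
        seen.append((kwargs["key"], kwargs["repr"], kwargs["already_compiled"]))
        return False  # falsy: let compilation / cache loading proceed as usual

    JITFunction.cache_hook = log_hook
    # ... run some @triton.jit kernels, then inspect `seen` ...
    JITFunction.cache_hook = None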

View File

@@ -602,8 +602,19 @@ class Kernel:
             return 'str'
         raise NotImplementedError(f'could not compute type name for {obj}')
 
+    @staticmethod
+    def _to_python_ir(obj):
+        # convert torch.Tensor to Triton IR pointers
+        if hasattr(obj, 'data_ptr'):
+            name = Kernel._type_name(obj)
+            return 'ptr', name
+        # default path returns triton.ir.type directly
+        name = Kernel._type_name(obj)
+        return 'scalar', name
+
     @staticmethod
     def _to_triton_ir(context, obj):
+        which, name = obj
         type_map = {
             'I': _triton.ir.type.get_int32,
             'L': _triton.ir.type.get_int64,
@@ -625,12 +636,10 @@ class Kernel:
             'u64': _triton.ir.type.get_uint64,
         }
         # convert torch.Tensor to Triton IR pointers
-        if hasattr(obj, 'data_ptr'):
-            name = Kernel._type_name(obj)
+        if which == 'ptr':
            elt_ty = type_map[name](context)
            return _triton.ir.type.make_ptr(elt_ty, 1)
        # default path returns triton.ir.type directly
-        name = Kernel._type_name(obj)
        return type_map[name](context)
 
     @staticmethod
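
Typing is now split into two steps: `_to_python_ir` reduces each argument to a plain `(kind, name)` tuple that can be stored and passed around without holding on to live tensors, and `_to_triton_ir` later turns that tuple into an actual `_triton.ir` type inside `_compile`. An illustrative sketch, assuming the module layout of this commit (`triton/code_gen.py`) and that `Kernel._type_name` maps a float32 tensor to `'f32'`:

    import torch
    from triton.code_gen import Kernel

    x = torch.empty(1024, dtype=torch.float32)  # only the dtype matters here

    # step 1: a picklable description of the argument, no torch objects kept
    kind, name = Kernel._to_python_ir(x)        # expected: ('ptr', 'f32')

    # step 2 happens inside _compile: the (kind, name) tuple is mapped to a
    # Triton IR type -- a pointer to the element type for 'ptr', and the
    # scalar type itself otherwise (see _to_triton_ir above).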
@@ -648,36 +657,6 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
-        # create IR module
-        context = _triton.ir.context()
-        # get just-in-time proto-type of kernel
-        fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
-        arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
-        ret_type = _triton.ir.type.get_void(context)
-        prototype = _triton.ir.type.make_function(ret_type, arg_types)
-        # generate Triton-IR
-        # export symbols visible from self.fn into code-generator object
-        gscope = self.fn.__globals__
-        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
-        try:
-            generator.visit(self.fn.parse())
-        except Exception as e:
-            node = generator.last_node
-            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
-                raise e
-            raise CompilationError(self.fn.src, node) from e
-        # Compile to machine code
-        if torch.version.hip is None:
-            backend = _triton.runtime.backend.CUDA
-        else:
-            backend = _triton.runtime.backend.ROCM
-        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
-        if shared_mem > max_shared_memory:
-            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, name, asm, shared_mem, num_warps)
-
     def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         # attributes
@@ -692,57 +671,12 @@ class Kernel:
             range_size = _triton.runtime.get_pointer_range_size(addr)
             attributes[i] = min(Kernel.pow2_divisor(addr),
                                 Kernel.pow2_divisor(range_size))
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
         constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
         constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
-        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
-
-        # create cache directory
-        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
-        if cache_dir and not os.path.exists(cache_dir):
-            os.makedirs(cache_dir, exist_ok=True)
-
-        if cache_dir:
-            bin_cache_path = os.path.join(cache_dir, hashed_key)
-            bin_lock_path = bin_cache_path + ".lock"
-        else:
-            bin_cache_path = None
-            bin_lock_path = None
-
-        binary = None
-        if bin_cache_path and os.path.exists(bin_cache_path):
-            assert bin_lock_path is not None
-            with FileLock(bin_lock_path):
-                with open(bin_cache_path, 'rb') as f:
-                    binary = pickle.load(f)["binary"]
-        if binary is None:
-            binary = self._compile(
-                *wargs, device=device_idx, attributes=attributes,
-                num_warps=num_warps, num_stages=num_stages,
-                constants=constants,
-            )
-            if bin_cache_path:
-                assert bin_lock_path is not None
-                with FileLock(bin_lock_path):
-                    with open(bin_cache_path + ".tmp", "wb") as f:
-                        pickle.dump({"binary": binary, "key": key}, f)
-                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
-
-        if JITFunction.cache_hook is not None:
-            name = self.fn.__name__
-            info = key.split('-')[-3:]
-            num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
-            # make signature human-readable
-            arg_reprs = []
-            for arg_name, arg_sig in zip(self.fn.arg_names, sig):
-                arg_reprs.append(f'{arg_name}: {arg_sig}')
-            # assemble the repr
-            arg_reprs = ", ".join(arg_reprs)
-            repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
-            JITFunction.cache_hook(key=key, binary=binary, repr=repr)
-
-        self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
+        arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
+        return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # handle arguments passed by name
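
`Kernel.add_to_cache` no longer compiles anything or touches the on-disk cache itself: it reduces the arguments to `(kind, name)` tuples via `_to_python_ir` and delegates to `JITFunction._warmup`, whose boolean return travels back through the C++ launcher as the `noop` flag added in the first file. A hypothetical example of the keyword arguments it forwards (values below are illustrative, not taken from a real run; the cache-key format is left opaque):

    compile_spec = dict(
        key="<raw cache key string>",                    # format not shown here
        arg_types=[('ptr', 'f32'), ('ptr', 'f32'), ('scalar', 'I')],  # from Kernel._to_python_ir
        device=0,                                        # device index
        attributes={0: 16, 1: 16},                       # pointer alignment hints
        constants={3: 1},                                # specialized args (value 1, constexpr, None)
        num_warps=4,
        num_stages=2,
    )
    # JITFunction._warmup(**compile_spec, is_manual_warmup=False) returns True
    # only when the cache hook asked for a no-op; add_to_cache returns that flag.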
@@ -1027,6 +961,89 @@ class JITFunction:
         self.kernel = decorator(self.kernel)
         return self.kernel
 
+    def warmup(self, compile):
+        return self._warmup(**compile, is_manual_warmup=True)
+
+    def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup):
+        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
+
+        # create cache directory
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if cache_dir:
+            os.makedirs(cache_dir, exist_ok=True)
+
+        if cache_dir:
+            bin_cache_path = os.path.join(cache_dir, hashed_key)
+            bin_lock_path = bin_cache_path + ".lock"
+        else:
+            bin_cache_path = None
+            bin_lock_path = None
+
+        binary = None
+        if bin_cache_path and os.path.exists(bin_cache_path):
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path, 'rb') as f:
+                    binary = pickle.load(f)["binary"]
+
+        compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages)
+        if JITFunction.cache_hook is not None:
+            name = self.__name__
+            info = key.split('-')[-3:]
+            num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
+            # make signature human-readable
+            arg_reprs = []
+            for arg_name, arg_sig in zip(self.arg_names, sig):
+                arg_reprs.append(f'{arg_name}: {arg_sig}')
+            # assemble the repr
+            arg_reprs = ", ".join(arg_reprs)
+            repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
+            noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None)
+            if noop:
+                return True
+
+        if binary is None:
+            binary = self._compile(**compile)
+            if bin_cache_path:
+                assert bin_lock_path is not None
+                with FileLock(bin_lock_path):
+                    with open(bin_cache_path + ".tmp", "wb") as f:
+                        pickle.dump({"binary": binary, "key": key}, f)
+                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
+
+        self.bin_cache[key] = LoadedBinary(device, binary)
+        return False
+
+    def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
+        # create IR module
+        context = _triton.ir.context()
+        # get just-in-time proto-type of kernel
+        arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
+        ret_type = _triton.ir.type.get_void(context)
+        prototype = _triton.ir.type.make_function(ret_type, arg_types)
+        # generate Triton-IR
+        # export symbols visible from self into code-generator object
+        gscope = self.__globals__
+        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
+        try:
+            generator.visit(self.parse())
+        except Exception as e:
+            node = generator.last_node
+            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
+                raise e
+            raise CompilationError(self.src, node) from e
+        # Compile to machine code
+        if torch.version.hip is None:
+            backend = _triton.runtime.backend.CUDA
+        else:
+            backend = _triton.runtime.backend.ROCM
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
+        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+        if shared_mem > max_shared_memory:
+            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
+        return Binary(backend, name, asm, shared_mem, num_warps)
+
     def __getitem__(self, grid):
         return Launcher(self._init_kernel(), grid)
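
Taken together, the new hook contract and the `warmup` entry point allow compilation work to be recorded at call time and replayed later, e.g. to pre-populate the binary cache without ever launching a kernel. A hedged end-to-end sketch; the kernel, sizes and import path are illustrative, and a CUDA device plus this revision of Triton are assumed:

    import torch
    import triton
    import triton.language as tl
    from triton.code_gen import JITFunction

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        # illustrative element-wise add
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(out_ptr + offs,
                 tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
                 mask=mask)

    pending = []

    def defer(**kwargs):
        # remember everything _warmup needs, then report a no-op: the runtime
        # returns None instead of compiling and running the kernel
        pending.append((kwargs["fn"], kwargs["compile"]))
        return True

    x = torch.randn(1024, device='cuda')
    y = torch.randn(1024, device='cuda')
    out = torch.empty_like(x)

    JITFunction.cache_hook = defer
    add_kernel[(1,)](x, y, out, x.numel(), BLOCK=1024)   # recorded, not executed
    JITFunction.cache_hook = None

    # later: compile and fill bin_cache / the on-disk cache without launching
    for fn, compile_kwargs in pending:
        fn.warmup(compile_kwargs)

Because the `compile` dict handed to the hook already contains the raw cache key, `fn.warmup(...)` rebuilds exactly the cache entry the original call would have produced.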