[FRONTEND] improved caching mechanism (#474)
Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
@@ -299,8 +299,12 @@ void init_triton_runtime(py::module &&m) {
     // get cached binary
     py::str key(cache_key);
-    if(!bin_cache.contains(key))
-      add_to_cache(key, args, device, num_warps, num_stages);
+    py::bool_ noop = false;
+    if(!bin_cache.contains(key)) {
+      noop = add_to_cache(key, args, device, num_warps, num_stages);
+    }
+    if (noop)
+      return (py::object)py::none();
     py::object bin = bin_cache[key];

     // get grid
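In rough Python terms, the launch path above now works as follows. This is a hedged sketch: `bin_cache`, `add_to_cache`, and `launch` stand in for the pybind11-side objects rather than quoting the real runtime.

    # Python model of the new C++ control flow (illustrative names only).
    bin_cache = {}

    def add_to_cache(key, args, device, num_warps, num_stages):
        # In the real runtime this calls into JITFunction._warmup and returns
        # True when a cache hook asked to skip compilation and execution.
        bin_cache[key] = ('binary', key)   # placeholder for a compiled kernel
        return False

    def launch(key, args, device=0, num_warps=4, num_stages=2):
        noop = False
        if key not in bin_cache:
            noop = add_to_cache(key, args, device, num_warps, num_stages)
        if noop:
            return None           # hook intercepted: nothing to launch
        return bin_cache[key]     # proceed to grid computation and launch

The point of the change: a cache hook can now make the whole launch a no-op, so the launcher returns `py::none()` instead of assuming a binary always exists.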
@@ -529,6 +533,7 @@ void init_triton_codegen(py::module &&m) {
       return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
   }, py::return_value_policy::take_ownership);
   m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+    py::gil_scoped_release allow_threads;
     if(backend == CUDA)
       return cu_load_binary(name, asm_map, n_shared_bytes, dev);
     if(backend == ROCM)
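The added `py::gil_scoped_release` drops the Python GIL for the duration of `load_binary`, so driver-side module loading no longer blocks other Python threads. A hedged sketch of what that enables on the Python side (the thread-pool usage is illustrative, not part of this commit):

    # Overlapping several load_binary calls from Python threads; this only
    # helps because the binding now releases the GIL internally.
    from concurrent.futures import ThreadPoolExecutor

    def load_all(load_binary, jobs, dev=0):
        # jobs: iterable of (name, asm_map, n_shared_bytes) tuples
        with ThreadPoolExecutor(max_workers=4) as pool:
            futures = [pool.submit(load_binary, name, asm_map, shmem, dev)
                       for name, asm_map, shmem in jobs]
            return [f.result() for f in futures]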
@@ -76,7 +76,7 @@ def reset_tmp_dir():
 def test_reuse():
     counter = 0

-    def inc_counter(key, binary, repr):
+    def inc_counter(*args, **kwargs):
         nonlocal counter
         counter += 1
     JITFunction.cache_hook = inc_counter
@@ -91,7 +91,7 @@ def test_reuse():
 def test_specialize(mode):
     counter = 0

-    def inc_counter(key, binary, repr):
+    def inc_counter(*args, **kwargs):
         nonlocal counter
         counter += 1
     JITFunction.cache_hook = inc_counter
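Both tests switch their hooks from the old positional signature `(key, binary, repr)` to `*args, **kwargs`: the hook is now invoked with a richer set of keyword arguments (`key`, `repr`, `fn`, `compile`, `is_manual_warmup`, `already_compiled`; see the JITFunction hunk below), and a catch-all signature keeps the tests insulated from further additions. A hedged sketch of the same pattern outside the tests:

    # Forward-compatible cache hook: accept everything, read only what you need.
    def counting_hook(*args, **kwargs):
        counting_hook.calls += 1
        return False    # falsy return value: compile and launch as usual
    counting_hook.calls = 0

    # JITFunction.cache_hook = counting_hook   # installed as in the tests above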
@@ -602,8 +602,19 @@ class Kernel:
             return 'str'
         raise NotImplementedError(f'could not compute type name for {obj}')

+    @staticmethod
+    def _to_python_ir(obj):
+        # convert torch.Tensor to Triton IR pointers
+        if hasattr(obj, 'data_ptr'):
+            name = Kernel._type_name(obj)
+            return 'ptr', name
+        # default path returns triton.ir.type directly
+        name = Kernel._type_name(obj)
+        return 'scalar', name
+
     @staticmethod
     def _to_triton_ir(context, obj):
+        which, name = obj
         type_map = {
             'I': _triton.ir.type.get_int32,
             'L': _triton.ir.type.get_int64,
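The new `_to_python_ir` reduces each argument to a plain `('ptr' | 'scalar', type_name)` tuple, so a compile specification can be built, stored, or handed to a cache hook without keeping live tensors around. A hedged standalone model (the `FakeTensor` class and the `'f32'`/`'I'` names are assumptions for illustration):

    import pickle

    class FakeTensor:                  # stands in for a torch.Tensor
        def data_ptr(self):
            return 0

    def to_python_ir(obj):
        # mirrors Kernel._to_python_ir with hard-coded type names
        if hasattr(obj, 'data_ptr'):
            return 'ptr', 'f32'        # real code: Kernel._type_name(obj)
        return 'scalar', 'I'

    spec = [to_python_ir(a) for a in (FakeTensor(), 128)]
    assert spec == [('ptr', 'f32'), ('scalar', 'I')]
    pickle.dumps(spec)                 # unlike tensors, the spec serializes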
@@ -625,12 +636,10 @@ class Kernel:
             'u64': _triton.ir.type.get_uint64,
         }
         # convert torch.Tensor to Triton IR pointers
-        if hasattr(obj, 'data_ptr'):
-            name = Kernel._type_name(obj)
+        if which == 'ptr':
             elt_ty = type_map[name](context)
             return _triton.ir.type.make_ptr(elt_ty, 1)
         # default path returns triton.ir.type directly
-        name = Kernel._type_name(obj)
         return type_map[name](context)

     @staticmethod
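`_to_triton_ir` now consumes that tagged tuple instead of inspecting the object itself, which is why both `name = Kernel._type_name(obj)` lines disappear: the name already travels inside the tuple. A minimal self-contained model of the dispatch (the string type constructors are stand-ins for the `_triton.ir.type` calls):

    TYPE_MAP = {'I': 'int32', 'L': 'int64', 'f32': 'float32'}   # illustrative

    def to_triton_ir(obj):
        which, name = obj                  # the ('ptr' | 'scalar', name) tuple
        if which == 'ptr':
            return TYPE_MAP[name] + '*'    # real code: ir.type.make_ptr(elt_ty, 1)
        return TYPE_MAP[name]              # real code: type_map[name](context)

    assert to_triton_ir(('ptr', 'f32')) == 'float32*'
    assert to_triton_ir(('scalar', 'I')) == 'int32'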
@@ -648,36 +657,6 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn

-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
-        # create IR module
-        context = _triton.ir.context()
-        # get just-in-time proto-type of kernel
-        fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
-        arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
-        ret_type = _triton.ir.type.get_void(context)
-        prototype = _triton.ir.type.make_function(ret_type, arg_types)
-        # generate Triton-IR
-        # export symbols visible from self.fn into code-generator object
-        gscope = self.fn.__globals__
-        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
-        try:
-            generator.visit(self.fn.parse())
-        except Exception as e:
-            node = generator.last_node
-            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
-                raise e
-            raise CompilationError(self.fn.src, node) from e
-        # Compile to machine code
-        if torch.version.hip is None:
-            backend = _triton.runtime.backend.CUDA
-        else:
-            backend = _triton.runtime.backend.ROCM
-        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
-        if shared_mem > max_shared_memory:
-            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, name, asm, shared_mem, num_warps)
-
     def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         # attributes
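`Kernel._compile` is deleted here rather than lost: an equivalent `_compile` reappears on `JITFunction` (see the `+961` hunk below), taking pre-computed `arg_types` instead of raw `*wargs`. A hedged sketch of the resulting division of labor, with all cache and codegen details elided behind placeholders:

    # Kernel summarizes arguments; JITFunction owns compilation and caching.
    class JITFunctionSketch:
        def _warmup(self, key, arg_types, is_manual_warmup):
            binary = self._compile(arg_types)      # disk-cache checks elided
            return False                           # False => proceed to launch

        def _compile(self, arg_types):
            return ('binary', tuple(arg_types))    # placeholder for codegen

    class KernelSketch:
        def __init__(self, fn):
            self.fn = fn

        def add_to_cache(self, key, wargs):
            arg_types = ['ptr' if hasattr(a, 'data_ptr') else 'scalar' for a in wargs]
            return self.fn._warmup(key, arg_types, is_manual_warmup=False)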
@@ -692,57 +671,12 @@ class Kernel:
                 range_size = _triton.runtime.get_pointer_range_size(addr)
                 attributes[i] = min(Kernel.pow2_divisor(addr),
                                     Kernel.pow2_divisor(range_size))

         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
         constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
         constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
-        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
+        arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
+        return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
-        # create cache directory
-        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
-        if cache_dir and not os.path.exists(cache_dir):
-            os.makedirs(cache_dir, exist_ok=True)
-
-        if cache_dir:
-            bin_cache_path = os.path.join(cache_dir, hashed_key)
-            bin_lock_path = bin_cache_path + ".lock"
-        else:
-            bin_cache_path = None
-            bin_lock_path = None
-
-        binary = None
-        if bin_cache_path and os.path.exists(bin_cache_path):
-            assert bin_lock_path is not None
-            with FileLock(bin_lock_path):
-                with open(bin_cache_path, 'rb') as f:
-                    binary = pickle.load(f)["binary"]
-        if binary is None:
-            binary = self._compile(
-                *wargs, device=device_idx, attributes=attributes,
-                num_warps=num_warps, num_stages=num_stages,
-                constants=constants,
-            )
-        if bin_cache_path:
-            assert bin_lock_path is not None
-            with FileLock(bin_lock_path):
-                with open(bin_cache_path + ".tmp", "wb") as f:
-                    pickle.dump({"binary": binary, "key": key}, f)
-                os.rename(bin_cache_path + ".tmp", bin_cache_path)
-        if JITFunction.cache_hook is not None:
-            name = self.fn.__name__
-            info = key.split('-')[-3:]
-            num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
-            # make signature human-readable
-            arg_reprs = []
-            for arg_name, arg_sig in zip(self.fn.arg_names, sig):
-                arg_reprs.append(f'{arg_name}: {arg_sig}')
-            # assemble the repr
-            arg_reprs = ", ".join(arg_reprs)
-            repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
-            JITFunction.cache_hook(key=key, binary=binary, repr=repr)
-
-        self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
-
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # handle arguments passed by name
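The context lines kept in this hunk show the specialization rule that still runs before the hand-off: integer arguments equal to 1 become compile-time constants unless the kernel opted out via `do_not_specialize`, and `None` arguments are constant-folded too. A hedged standalone model (the `constexpr` case is omitted for brevity):

    # Specialization rule from the hunk above, isolated for illustration.
    def find_constants(wargs, do_not_specialize=()):
        constants = {i: a for i, a in enumerate(wargs)
                     if isinstance(a, int) and a == 1 and i not in do_not_specialize}
        constants.update({i: None for i, a in enumerate(wargs) if a is None})
        return constants

    assert find_constants([1, 7, None]) == {0: 1, 2: None}
    assert find_constants([1, 7, None], do_not_specialize={0}) == {2: None}

Everything below the new `return self.fn._warmup(...)` line is the old in-`Kernel` caching code, removed here because `_warmup` now owns it.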
@@ -1027,6 +961,89 @@ class JITFunction:
         self.kernel = decorator(self.kernel)
         return self.kernel

+    def warmup(self, compile):
+        return self._warmup(**compile, is_manual_warmup=True)
+
+    def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup):
+        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
+
+        # create cache directory
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if cache_dir:
+            os.makedirs(cache_dir, exist_ok=True)
+
+        if cache_dir:
+            bin_cache_path = os.path.join(cache_dir, hashed_key)
+            bin_lock_path = bin_cache_path + ".lock"
+        else:
+            bin_cache_path = None
+            bin_lock_path = None
+
+        binary = None
+        if bin_cache_path and os.path.exists(bin_cache_path):
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path, 'rb') as f:
+                    binary = pickle.load(f)["binary"]
+
+        compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages)
+        if JITFunction.cache_hook is not None:
+            name = self.__name__
+            info = key.split('-')[-3:]
+            num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
+            # make signature human-readable
+            arg_reprs = []
+            for arg_name, arg_sig in zip(self.arg_names, sig):
+                arg_reprs.append(f'{arg_name}: {arg_sig}')
+            # assemble the repr
+            arg_reprs = ", ".join(arg_reprs)
+            repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
+            noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None)
+            if noop:
+                return True
+
+        if binary is None:
+            binary = self._compile(**compile)
+
+        if bin_cache_path:
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path + ".tmp", "wb") as f:
+                    pickle.dump({"binary": binary, "key": key}, f)
+                os.rename(bin_cache_path + ".tmp", bin_cache_path)
+
+        self.bin_cache[key] = LoadedBinary(device, binary)
+        return False
+
+    def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
+        # create IR module
+        context = _triton.ir.context()
+        # get just-in-time proto-type of kernel
+        arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
+        ret_type = _triton.ir.type.get_void(context)
+        prototype = _triton.ir.type.make_function(ret_type, arg_types)
+        # generate Triton-IR
+        # export symbols visible from self into code-generator object
+        gscope = self.__globals__
+        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
+        try:
+            generator.visit(self.parse())
+        except Exception as e:
+            node = generator.last_node
+            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
+                raise e
+            raise CompilationError(self.src, node) from e
+        # Compile to machine code
+        if torch.version.hip is None:
+            backend = _triton.runtime.backend.CUDA
+        else:
+            backend = _triton.runtime.backend.ROCM
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
+        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+        if shared_mem > max_shared_memory:
+            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
+        return Binary(backend, name, asm, shared_mem, num_warps)
+
     def __getitem__(self, grid):
         return Launcher(self._init_kernel(), grid)
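The new contract in one place: `_warmup` hashes the key, checks the on-disk cache under a file lock, then consults `JITFunction.cache_hook`; a truthy return from the hook means "noop", propagated back through `add_to_cache` to the C++ launcher. Because the hook now receives the full `compile` dict, it can defer work and replay it later through `warmup`. A hedged record-and-replay sketch:

    # Record compile specs instead of compiling; replay them explicitly later.
    recorded = []

    def record_and_skip(key, repr, fn, compile, is_manual_warmup,
                        already_compiled, **kwargs):
        if not already_compiled:
            recorded.append((fn, compile))
        return True                    # noop: skip compilation and launch

    # JITFunction.cache_hook = record_and_skip
    # ... run the workload once to trace it, then:
    # for fn, spec in recorded:
    #     fn.warmup(spec)              # spec already carries its "key"

Note also the atomic publish pattern carried over from the old code: the binary is pickled to `bin_cache_path + ".tmp"` and then `os.rename`d into place under the file lock, so concurrent processes never observe a half-written cache entry.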