diff --git a/python/src/triton.cc b/python/src/triton.cc index ce56d9c26..b44ffbc27 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -117,7 +117,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params.reserve(8*len); // 8 max bytes by argument char* params_ptr = &params[0]; cache_key = func_key; + cache_key += "-" + std::to_string(num_warps); + cache_key += "-" + std::to_string(num_stages); + cache_key += "-"; for(int i = 0; i < len; i++){ + cache_key += "_"; py::int_ py_i = py::int_(i); bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end(); py::object arg = args[i]; @@ -127,19 +131,20 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(PyLong_Check(arg_ptr)){ int overflow; long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); + // values equal to 1 are specialized if(specialize && (value == 1)){ - cache_key += '1'; + cache_key += "1"; continue; } // long and int have different kernels if(!overflow & (std::abs(value) <= 0xffffffff)){ - cache_key += 'I'; + cache_key += "int32"; params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); std::memcpy(params_ptr, &value, 4); params_ptr += 4; } else{ - cache_key += 'L'; + cache_key += "int64"; params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); if(overflow){ unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); @@ -150,15 +155,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f } if(!specialize) continue; - // values equal to 1 are specialized - cache_key += 'x'; // values divisible by small powers of 2 are specialized + cache_key += "[multipleof("; cache_key += pow2_divisor(value); + cache_key += ")]"; continue; } // argument is `float` if(PyFloat_Check(arg_ptr)){ - cache_key += "f"; + cache_key += "float32"; float value = PyFloat_AsDouble(arg_ptr); params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); 
std::memcpy(params_ptr, &value, 4); @@ -167,7 +172,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f } // argument is `bool` if(PyBool_Check(arg_ptr)){ - cache_key += "B"; + cache_key += "bool"; bool value = arg_ptr == Py_True ? true : false; std::memcpy(params_ptr, &value, 1); params_ptr += 1; @@ -176,7 +181,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f // argument is tensor if(py::hasattr(arg, "data_ptr")){ py::object data_ptr = arg.attr("data_ptr")(); - cache_key += "P"; long value = data_ptr.cast<long>(); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); std::memcpy(params_ptr, &value, 8); @@ -186,6 +190,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; cache_key += std::string(start, len); + cache_key += "*"; + cache_key += "[multipleof("; + cache_key += pow2_divisor(value); + cache_key += ")]"; continue; } // argument is `constexpr` @@ -208,8 +216,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; throw std::runtime_error(err_msg); } - cache_key += std::to_string(num_warps); - cache_key += std::to_string(num_stages); params_size = (std::ptrdiff_t)(params_ptr - &params[0]); } diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a1c994241..3ad387f09 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -64,20 +64,21 @@ def reset_tmp_dir(): def test_reuse(): counter = 0 - def inc_counter(key, binary): + def inc_counter(key, binary, repr): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') for i in range(10): - 
kernel[(1,)](x, 43, BLOCK=1024) + kernel[(1,)](x, 1, BLOCK=1024) assert counter == 1 + @pytest.mark.parametrize('mode', ['enable', 'disable']) def test_specialize(mode): counter = 0 - def inc_counter(key, binary): + def inc_counter(key, binary, repr): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 2f6ddf3c1..cacc39675 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -700,7 +700,17 @@ class Kernel: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) + name = self.fn.fn.__name__ + info = key.split('-')[-3:] + num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] + # make signature human-readable + arg_reprs = [] + for arg_name, arg_sig in zip(self.fn.arg_names, sig): + arg_reprs.append(f'{arg_name}: {arg_sig}') + # assemble the repr + arg_reprs = ", ".join(arg_reprs) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + JITFunction.cache_hook(key=key, binary=binary, repr=repr) self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)