[FRONTEND] Better cache hook (#400)

Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
This commit is contained in:
Philippe Tillet
2021-12-21 21:29:47 -08:00
committed by GitHub
parent 2509124dd0
commit a425f24d54
3 changed files with 31 additions and 14 deletions

View File

@@ -117,7 +117,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
params.reserve(8*len); // 8 max bytes by argument params.reserve(8*len); // 8 max bytes by argument
char* params_ptr = &params[0]; char* params_ptr = &params[0];
cache_key = func_key; cache_key = func_key;
cache_key += "-" + std::to_string(num_warps);
cache_key += "-" + std::to_string(num_stages);
cache_key += "-";
for(int i = 0; i < len; i++){ for(int i = 0; i < len; i++){
cache_key += "_";
py::int_ py_i = py::int_(i); py::int_ py_i = py::int_(i);
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end(); bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
py::object arg = args[i]; py::object arg = args[i];
@@ -127,19 +131,20 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
if(PyLong_Check(arg_ptr)){ if(PyLong_Check(arg_ptr)){
int overflow; int overflow;
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
// values equal to 1 are specialized
if(specialize && (value == 1)){ if(specialize && (value == 1)){
cache_key += '1'; cache_key += "1";
continue; continue;
} }
// long and int have different kernels // long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){ if(!overflow & (std::abs(value) <= 0xffffffff)){
cache_key += 'I'; cache_key += "int32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4); std::memcpy(params_ptr, &value, 4);
params_ptr += 4; params_ptr += 4;
} }
else{ else{
cache_key += 'L'; cache_key += "int64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
if(overflow){ if(overflow){
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
@@ -150,15 +155,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
} }
if(!specialize) if(!specialize)
continue; continue;
// values equal to 1 are specialized
cache_key += 'x';
// values divisible by small powers of 2 are specialized // values divisible by small powers of 2 are specialized
cache_key += "[multipleof(";
cache_key += pow2_divisor(value); cache_key += pow2_divisor(value);
cache_key += ")]";
continue; continue;
} }
// argument is `float` // argument is `float`
if(PyFloat_Check(arg_ptr)){ if(PyFloat_Check(arg_ptr)){
cache_key += "f"; cache_key += "float32";
float value = PyFloat_AsDouble(arg_ptr); float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4); std::memcpy(params_ptr, &value, 4);
@@ -167,7 +172,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
} }
// argument is `bool` // argument is `bool`
if(PyBool_Check(arg_ptr)){ if(PyBool_Check(arg_ptr)){
cache_key += "B"; cache_key += "bool";
bool value = arg_ptr == Py_True ? true : false; bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1); std::memcpy(params_ptr, &value, 1);
params_ptr += 1; params_ptr += 1;
@@ -176,7 +181,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
// argument is tensor // argument is tensor
if(py::hasattr(arg, "data_ptr")){ if(py::hasattr(arg, "data_ptr")){
py::object data_ptr = arg.attr("data_ptr")(); py::object data_ptr = arg.attr("data_ptr")();
cache_key += "P";
long value = data_ptr.cast<long>(); long value = data_ptr.cast<long>();
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8); std::memcpy(params_ptr, &value, 8);
@@ -186,6 +190,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len); cache_key += std::string(start, len);
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
cache_key += ")]";
continue; continue;
} }
// argument is `constexpr` // argument is `constexpr`
@@ -208,8 +216,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
throw std::runtime_error(err_msg); throw std::runtime_error(err_msg);
} }
cache_key += std::to_string(num_warps);
cache_key += std::to_string(num_stages);
params_size = (std::ptrdiff_t)(params_ptr - &params[0]); params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
} }

View File

@@ -64,20 +64,21 @@ def reset_tmp_dir():
def test_reuse(): def test_reuse():
counter = 0 counter = 0
def inc_counter(key, binary): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1
JITFunction.cache_hook = inc_counter JITFunction.cache_hook = inc_counter
reset_tmp_dir() reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda') x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10): for i in range(10):
kernel[(1,)](x, 43, BLOCK=1024) kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1 assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable']) @pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode): def test_specialize(mode):
counter = 0 counter = 0
def inc_counter(key, binary): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1
JITFunction.cache_hook = inc_counter JITFunction.cache_hook = inc_counter

View File

@@ -700,7 +700,17 @@ class Kernel:
pickle.dump({"binary": binary, "key": key}, f) pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path) os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None: if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary) name = self.fn.fn.__name__
info = key.split('-')[-3:]
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
# make signature human-readable
arg_reprs = []
for arg_name, arg_sig in zip(self.fn.arg_names, sig):
arg_reprs.append(f'{arg_name}: {arg_sig}')
# assemble the repr
arg_reprs = ", ".join(arg_reprs)
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
JITFunction.cache_hook(key=key, binary=binary, repr=repr)
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)