[FRONTEND] Better cache hook (#400)
Added a `repr` argument to the cache hook: a human-readable string describing the signature and argument attributes associated with the compiled binary.
This commit is contained in:
@@ -117,7 +117,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
params.reserve(8*len); // 8 max bytes by argument
|
params.reserve(8*len); // 8 max bytes by argument
|
||||||
char* params_ptr = &params[0];
|
char* params_ptr = &params[0];
|
||||||
cache_key = func_key;
|
cache_key = func_key;
|
||||||
|
cache_key += "-" + std::to_string(num_warps);
|
||||||
|
cache_key += "-" + std::to_string(num_stages);
|
||||||
|
cache_key += "-";
|
||||||
for(int i = 0; i < len; i++){
|
for(int i = 0; i < len; i++){
|
||||||
|
cache_key += "_";
|
||||||
py::int_ py_i = py::int_(i);
|
py::int_ py_i = py::int_(i);
|
||||||
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
|
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
|
||||||
py::object arg = args[i];
|
py::object arg = args[i];
|
||||||
@@ -127,19 +131,20 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
if(PyLong_Check(arg_ptr)){
|
if(PyLong_Check(arg_ptr)){
|
||||||
int overflow;
|
int overflow;
|
||||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||||
|
// values equal to 1 are specialized
|
||||||
if(specialize && (value == 1)){
|
if(specialize && (value == 1)){
|
||||||
cache_key += '1';
|
cache_key += "1";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// long and int have different kernels
|
// long and int have different kernels
|
||||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||||
cache_key += 'I';
|
cache_key += "int32";
|
||||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||||
std::memcpy(params_ptr, &value, 4);
|
std::memcpy(params_ptr, &value, 4);
|
||||||
params_ptr += 4;
|
params_ptr += 4;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
cache_key += 'L';
|
cache_key += "int64";
|
||||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||||
if(overflow){
|
if(overflow){
|
||||||
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
|
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
|
||||||
@@ -150,15 +155,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
}
|
}
|
||||||
if(!specialize)
|
if(!specialize)
|
||||||
continue;
|
continue;
|
||||||
// values equal to 1 are specialized
|
|
||||||
cache_key += 'x';
|
|
||||||
// values divisible by small powers of 2 are specialized
|
// values divisible by small powers of 2 are specialized
|
||||||
|
cache_key += "[multipleof(";
|
||||||
cache_key += pow2_divisor(value);
|
cache_key += pow2_divisor(value);
|
||||||
|
cache_key += ")]";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// argument is `float`
|
// argument is `float`
|
||||||
if(PyFloat_Check(arg_ptr)){
|
if(PyFloat_Check(arg_ptr)){
|
||||||
cache_key += "f";
|
cache_key += "float32";
|
||||||
float value = PyFloat_AsDouble(arg_ptr);
|
float value = PyFloat_AsDouble(arg_ptr);
|
||||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||||
std::memcpy(params_ptr, &value, 4);
|
std::memcpy(params_ptr, &value, 4);
|
||||||
@@ -167,7 +172,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
}
|
}
|
||||||
// argument is `bool`
|
// argument is `bool`
|
||||||
if(PyBool_Check(arg_ptr)){
|
if(PyBool_Check(arg_ptr)){
|
||||||
cache_key += "B";
|
cache_key += "bool";
|
||||||
bool value = arg_ptr == Py_True ? true : false;
|
bool value = arg_ptr == Py_True ? true : false;
|
||||||
std::memcpy(params_ptr, &value, 1);
|
std::memcpy(params_ptr, &value, 1);
|
||||||
params_ptr += 1;
|
params_ptr += 1;
|
||||||
@@ -176,7 +181,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
// argument is tensor
|
// argument is tensor
|
||||||
if(py::hasattr(arg, "data_ptr")){
|
if(py::hasattr(arg, "data_ptr")){
|
||||||
py::object data_ptr = arg.attr("data_ptr")();
|
py::object data_ptr = arg.attr("data_ptr")();
|
||||||
cache_key += "P";
|
|
||||||
long value = data_ptr.cast<long>();
|
long value = data_ptr.cast<long>();
|
||||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||||
std::memcpy(params_ptr, &value, 8);
|
std::memcpy(params_ptr, &value, 8);
|
||||||
@@ -186,6 +190,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
||||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
||||||
cache_key += std::string(start, len);
|
cache_key += std::string(start, len);
|
||||||
|
cache_key += "*";
|
||||||
|
cache_key += "[multipleof(";
|
||||||
|
cache_key += pow2_divisor(value);
|
||||||
|
cache_key += ")]";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// argument is `constexpr`
|
// argument is `constexpr`
|
||||||
@@ -208,8 +216,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||||
throw std::runtime_error(err_msg);
|
throw std::runtime_error(err_msg);
|
||||||
}
|
}
|
||||||
cache_key += std::to_string(num_warps);
|
|
||||||
cache_key += std::to_string(num_stages);
|
|
||||||
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
|
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -64,20 +64,21 @@ def reset_tmp_dir():
|
|||||||
|
|
||||||
def test_reuse():
|
def test_reuse():
|
||||||
counter = 0
|
counter = 0
|
||||||
def inc_counter(key, binary):
|
def inc_counter(key, binary, repr):
|
||||||
nonlocal counter
|
nonlocal counter
|
||||||
counter += 1
|
counter += 1
|
||||||
JITFunction.cache_hook = inc_counter
|
JITFunction.cache_hook = inc_counter
|
||||||
reset_tmp_dir()
|
reset_tmp_dir()
|
||||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
kernel[(1,)](x, 43, BLOCK=1024)
|
kernel[(1,)](x, 1, BLOCK=1024)
|
||||||
assert counter == 1
|
assert counter == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||||
def test_specialize(mode):
|
def test_specialize(mode):
|
||||||
counter = 0
|
counter = 0
|
||||||
def inc_counter(key, binary):
|
def inc_counter(key, binary, repr):
|
||||||
nonlocal counter
|
nonlocal counter
|
||||||
counter += 1
|
counter += 1
|
||||||
JITFunction.cache_hook = inc_counter
|
JITFunction.cache_hook = inc_counter
|
||||||
|
@@ -700,7 +700,17 @@ class Kernel:
|
|||||||
pickle.dump({"binary": binary, "key": key}, f)
|
pickle.dump({"binary": binary, "key": key}, f)
|
||||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||||
if JITFunction.cache_hook is not None:
|
if JITFunction.cache_hook is not None:
|
||||||
JITFunction.cache_hook(key=key, binary=binary)
|
name = self.fn.fn.__name__
|
||||||
|
info = key.split('-')[-3:]
|
||||||
|
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
||||||
|
# make signature human-readable
|
||||||
|
arg_reprs = []
|
||||||
|
for arg_name, arg_sig in zip(self.fn.arg_names, sig):
|
||||||
|
arg_reprs.append(f'{arg_name}: {arg_sig}')
|
||||||
|
# assemble the repr
|
||||||
|
arg_reprs = ", ".join(arg_reprs)
|
||||||
|
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||||
|
JITFunction.cache_hook(key=key, binary=binary, repr=repr)
|
||||||
|
|
||||||
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
|
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user