[FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
This commit is contained in:
@@ -117,7 +117,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
params.reserve(8*len); // 8 max bytes by argument
|
||||
char* params_ptr = ¶ms[0];
|
||||
cache_key = func_key;
|
||||
cache_key += "-" + std::to_string(num_warps);
|
||||
cache_key += "-" + std::to_string(num_stages);
|
||||
cache_key += "-";
|
||||
for(int i = 0; i < len; i++){
|
||||
cache_key += "_";
|
||||
py::int_ py_i = py::int_(i);
|
||||
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
|
||||
py::object arg = args[i];
|
||||
@@ -127,19 +131,20 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
if(PyLong_Check(arg_ptr)){
|
||||
int overflow;
|
||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||
// values equal to 1 are specialized
|
||||
if(specialize && (value == 1)){
|
||||
cache_key += '1';
|
||||
cache_key += "1";
|
||||
continue;
|
||||
}
|
||||
// long and int have different kernels
|
||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||
cache_key += 'I';
|
||||
cache_key += "int32";
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||
std::memcpy(params_ptr, &value, 4);
|
||||
params_ptr += 4;
|
||||
}
|
||||
else{
|
||||
cache_key += 'L';
|
||||
cache_key += "int64";
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
if(overflow){
|
||||
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
|
||||
@@ -150,15 +155,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
}
|
||||
if(!specialize)
|
||||
continue;
|
||||
// values equal to 1 are specialized
|
||||
cache_key += 'x';
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
// argument is `float`
|
||||
if(PyFloat_Check(arg_ptr)){
|
||||
cache_key += "f";
|
||||
cache_key += "float32";
|
||||
float value = PyFloat_AsDouble(arg_ptr);
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||
std::memcpy(params_ptr, &value, 4);
|
||||
@@ -167,7 +172,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
}
|
||||
// argument is `bool`
|
||||
if(PyBool_Check(arg_ptr)){
|
||||
cache_key += "B";
|
||||
cache_key += "bool";
|
||||
bool value = arg_ptr == Py_True ? true : false;
|
||||
std::memcpy(params_ptr, &value, 1);
|
||||
params_ptr += 1;
|
||||
@@ -176,7 +181,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
// argument is tensor
|
||||
if(py::hasattr(arg, "data_ptr")){
|
||||
py::object data_ptr = arg.attr("data_ptr")();
|
||||
cache_key += "P";
|
||||
long value = data_ptr.cast<long>();
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
@@ -186,6 +190,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
||||
cache_key += std::string(start, len);
|
||||
cache_key += "*";
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
// argument is `constexpr`
|
||||
@@ -208,8 +216,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||
throw std::runtime_error(err_msg);
|
||||
}
|
||||
cache_key += std::to_string(num_warps);
|
||||
cache_key += std::to_string(num_stages);
|
||||
params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]);
|
||||
}
|
||||
|
||||
|
@@ -64,20 +64,21 @@ def reset_tmp_dir():
|
||||
|
||||
def test_reuse():
|
||||
counter = 0
|
||||
def inc_counter(key, binary):
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
for i in range(10):
|
||||
kernel[(1,)](x, 43, BLOCK=1024)
|
||||
kernel[(1,)](x, 1, BLOCK=1024)
|
||||
assert counter == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||
def test_specialize(mode):
|
||||
counter = 0
|
||||
def inc_counter(key, binary):
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
|
@@ -700,7 +700,17 @@ class Kernel:
|
||||
pickle.dump({"binary": binary, "key": key}, f)
|
||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||
if JITFunction.cache_hook is not None:
|
||||
JITFunction.cache_hook(key=key, binary=binary)
|
||||
name = self.fn.fn.__name__
|
||||
info = key.split('-')[-3:]
|
||||
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
||||
# make signature human-readable
|
||||
arg_reprs = []
|
||||
for arg_name, arg_sig in zip(self.fn.arg_names, sig):
|
||||
arg_reprs.append(f'{arg_name}: {arg_sig}')
|
||||
# assemble the repr
|
||||
arg_reprs = ", ".join(arg_reprs)
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
JITFunction.cache_hook(key=key, binary=binary, repr=repr)
|
||||
|
||||
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
|
||||
|
||||
|
Reference in New Issue
Block a user