[RUNTIME] Restored do_not_specialize
(#374)
This commit is contained in:
@@ -110,7 +110,7 @@ std::string pow2_divisor(long N){
|
||||
}
|
||||
|
||||
// Launch
|
||||
void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
|
||||
void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
|
||||
int num_warps, int num_stages) {
|
||||
size_t len = PyList_Size(args.ptr());
|
||||
@@ -118,6 +118,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
|
||||
char* params_ptr = &params[0];
|
||||
cache_key = func_key;
|
||||
for(int i = 0; i < len; i++){
|
||||
PyObject* py_i = PyLong_FromLong(i);
|
||||
bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i);
|
||||
auto arg_ptr = PyList_GetItem(args.ptr(), i);
|
||||
auto arg = py::handle(arg_ptr);
|
||||
// argument is `long`
|
||||
@@ -141,6 +143,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
}
|
||||
if(!specialize)
|
||||
continue;
|
||||
// values equal to 1 are specialized
|
||||
if(value == 1)
|
||||
cache_key += '1';
|
||||
@@ -224,7 +228,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
);
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
|
||||
m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
|
||||
py::handle add_to_cache, py::handle grid){
|
||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||
@@ -234,7 +238,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
std::string params;
|
||||
size_t params_size;
|
||||
PyObject* constants = PyDict_New();
|
||||
parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
||||
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
||||
// get cached binary
|
||||
PyObject* key = PyUnicode_FromString(cache_key.c_str());
|
||||
PyObject* bin = nullptr;
|
||||
|
Reference in New Issue
Block a user