diff --git a/python/src/triton.cc b/python/src/triton.cc
index abe441b3a..5fb6bb8f2 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -110,7 +110,7 @@ std::string pow2_divisor(long N){
 }
 
 // Launch
-void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
+void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
                 std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
                 int num_warps, int num_stages) {
   size_t len = PyList_Size(args.ptr());
@@ -118,6 +118,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
   char* params_ptr = &params[0];
   cache_key = func_key;
   for(int i = 0; i < len; i++){
+    py::int_ py_i(i); // RAII wrapper; a bare PyLong_FromLong here would leak one object per argument
+    bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i.ptr());
     auto arg_ptr = PyList_GetItem(args.ptr(), i);
     auto arg = py::handle(arg_ptr);
     // argument is `long`
@@ -141,6 +143,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
       std::memcpy(params_ptr, &value, 8);
       params_ptr += 8;
     }
+    if(!specialize)
+      continue;
     // values equal to 1 are specialized
     if(value == 1)
       cache_key += '1';
@@ -224,7 +228,7 @@ void init_triton_runtime(py::module &&m) {
   );
 
   // cache key
-  m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
+  m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
                      py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
                      py::handle add_to_cache, py::handle grid){
     // parse arguments to compute cache key, compile-time constants and packed kernel arguments
@@ -234,7 +238,7 @@ void init_triton_runtime(py::module &&m) {
     std::string params;
     size_t params_size;
     PyObject* constants = PyDict_New();
-    parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
+    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
     // get cached binary
     PyObject* key = PyUnicode_FromString(cache_key.c_str());
     PyObject* bin = nullptr;
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 215b90d8b..a1c994241 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -4,6 +4,7 @@ from triton.code_gen import JITFunction
 import triton.language as tl
 import os
 import shutil
+import pytest
 
 tmpdir = ".tmp"
 
@@ -25,6 +26,11 @@ def kernel(X, i, BLOCK: tl.constexpr):
     i = function_1(i)
     tl.store(X, i)
 
+@triton.jit(do_not_specialize=["i"])
+def kernel_nospec(X, i, BLOCK: tl.constexpr):
+    i = i + 1
+    i = function_1(i)
+    tl.store(X, i)
 
 def apply_src_change(target, old, new):
     delattr(kernel.fn, 'hash')
@@ -51,16 +57,34 @@ def test_nested1_change():
     updated = apply_src_change(function_1, 'i + 1', 'i + 2')
     assert baseline != updated
 
+def reset_tmp_dir():
+    os.environ["TRITON_CACHE_DIR"] = tmpdir
+    if os.path.exists(tmpdir):
+        shutil.rmtree(tmpdir)
+
 def test_reuse():
     counter = 0
     def inc_counter(key, binary):
         nonlocal counter
         counter += 1
-    os.environ["TRITON_CACHE_DIR"] = tmpdir
-    if os.path.exists(tmpdir):
-        shutil.rmtree(tmpdir)
     JITFunction.cache_hook = inc_counter
+    reset_tmp_dir()
     x = torch.empty(1, dtype=torch.int32, device='cuda')
     for i in range(10):
         kernel[(1,)](x, 43, BLOCK=1024)
     assert counter == 1
+
+@pytest.mark.parametrize('mode', ['enable', 'disable'])
+def test_specialize(mode):
+    counter = 0
+    def inc_counter(key, binary):
+        nonlocal counter
+        counter += 1
+    JITFunction.cache_hook = inc_counter
+    reset_tmp_dir()
+    x = torch.empty(1, dtype=torch.int32, device='cuda')
+    function = {'enable': kernel, 'disable': kernel_nospec}[mode]
+    target = {'enable': 5, 'disable': 1}[mode]
+    for i in [1, 2, 4, 8, 16, 32]:
+        function[(1,)](x, i, BLOCK=512)
+    assert counter == target
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index d418cb9d1..73d114ed1 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -667,7 +667,7 @@ class Kernel:
         stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
         # stream = torch.cuda.current_stream(device).cuda_stream
         # make key for cache
-        return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
+        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
                                       self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
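
For context, a minimal usage sketch distilled from the test added above (`kernel_nospec` / `test_specialize`); the kernel name `add_scalar` is hypothetical, and it assumes `triton.jit` accepts and forwards `do_not_specialize` to the runtime (the decorator-side change is not shown in this diff):

    import torch
    import triton
    import triton.language as tl

    # "i" is exempt from specialization: its value (1, a multiple of 16, ...)
    # no longer feeds the cache key, so all launches below share one binary.
    @triton.jit(do_not_specialize=["i"])
    def add_scalar(X, i, BLOCK: tl.constexpr):
        tl.store(X, i + 1)

    x = torch.empty(1, dtype=torch.int32, device='cuda')
    for i in [1, 2, 4, 8, 16, 32]:
        add_scalar[(1,)](x, i, BLOCK=512)  # compiles once instead of five times

Without the flag, the same loop compiles five variants, one per specialization class of `i` (the value 1 plus the pow2_divisor classes 2, 4, 8, and 16), which is exactly the `'enable': 5` versus `'disable': 1` split that `test_specialize` asserts.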