[RUNTIME] Restored do_not_specialize
(#374)
This commit is contained in:
@@ -110,7 +110,7 @@ std::string pow2_divisor(long N){
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
|
void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
|
||||||
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
|
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
|
||||||
int num_warps, int num_stages) {
|
int num_warps, int num_stages) {
|
||||||
size_t len = PyList_Size(args.ptr());
|
size_t len = PyList_Size(args.ptr());
|
||||||
@@ -118,6 +118,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
|
|||||||
char* params_ptr = &params[0];
|
char* params_ptr = &params[0];
|
||||||
cache_key = func_key;
|
cache_key = func_key;
|
||||||
for(int i = 0; i < len; i++){
|
for(int i = 0; i < len; i++){
|
||||||
|
PyObject* py_i = PyLong_FromLong(i);
|
||||||
|
bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i);
|
||||||
auto arg_ptr = PyList_GetItem(args.ptr(), i);
|
auto arg_ptr = PyList_GetItem(args.ptr(), i);
|
||||||
auto arg = py::handle(arg_ptr);
|
auto arg = py::handle(arg_ptr);
|
||||||
// argument is `long`
|
// argument is `long`
|
||||||
@@ -141,6 +143,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n
|
|||||||
std::memcpy(params_ptr, &value, 8);
|
std::memcpy(params_ptr, &value, 8);
|
||||||
params_ptr += 8;
|
params_ptr += 8;
|
||||||
}
|
}
|
||||||
|
if(!specialize)
|
||||||
|
continue;
|
||||||
// values equal to 1 are specialized
|
// values equal to 1 are specialized
|
||||||
if(value == 1)
|
if(value == 1)
|
||||||
cache_key += '1';
|
cache_key += '1';
|
||||||
@@ -224,7 +228,7 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// cache key
|
// cache key
|
||||||
m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
|
m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||||
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
|
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
|
||||||
py::handle add_to_cache, py::handle grid){
|
py::handle add_to_cache, py::handle grid){
|
||||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||||
@@ -234,7 +238,7 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
std::string params;
|
std::string params;
|
||||||
size_t params_size;
|
size_t params_size;
|
||||||
PyObject* constants = PyDict_New();
|
PyObject* constants = PyDict_New();
|
||||||
parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
||||||
// get cached binary
|
// get cached binary
|
||||||
PyObject* key = PyUnicode_FromString(cache_key.c_str());
|
PyObject* key = PyUnicode_FromString(cache_key.c_str());
|
||||||
PyObject* bin = nullptr;
|
PyObject* bin = nullptr;
|
||||||
|
@@ -4,6 +4,7 @@ from triton.code_gen import JITFunction
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import pytest
|
||||||
|
|
||||||
tmpdir = ".tmp"
|
tmpdir = ".tmp"
|
||||||
|
|
||||||
@@ -25,6 +26,11 @@ def kernel(X, i, BLOCK: tl.constexpr):
|
|||||||
i = function_1(i)
|
i = function_1(i)
|
||||||
tl.store(X, i)
|
tl.store(X, i)
|
||||||
|
|
||||||
|
@triton.jit(do_not_specialize=["i"])
|
||||||
|
def kernel_nospec(X, i, BLOCK: tl.constexpr):
|
||||||
|
i = i + 1
|
||||||
|
i = function_1(i)
|
||||||
|
tl.store(X, i)
|
||||||
|
|
||||||
def apply_src_change(target, old, new):
|
def apply_src_change(target, old, new):
|
||||||
delattr(kernel.fn, 'hash')
|
delattr(kernel.fn, 'hash')
|
||||||
@@ -51,16 +57,34 @@ def test_nested1_change():
|
|||||||
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
|
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
|
||||||
assert baseline != updated
|
assert baseline != updated
|
||||||
|
|
||||||
|
def reset_tmp_dir():
|
||||||
|
os.environ["TRITON_CACHE_DIR"] = tmpdir
|
||||||
|
if os.path.exists(tmpdir):
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
def test_reuse():
|
def test_reuse():
|
||||||
counter = 0
|
counter = 0
|
||||||
def inc_counter(key, binary):
|
def inc_counter(key, binary):
|
||||||
nonlocal counter
|
nonlocal counter
|
||||||
counter += 1
|
counter += 1
|
||||||
os.environ["TRITON_CACHE_DIR"] = tmpdir
|
|
||||||
if os.path.exists(tmpdir):
|
|
||||||
shutil.rmtree(tmpdir)
|
|
||||||
JITFunction.cache_hook = inc_counter
|
JITFunction.cache_hook = inc_counter
|
||||||
|
reset_tmp_dir()
|
||||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
kernel[(1,)](x, 43, BLOCK=1024)
|
kernel[(1,)](x, 43, BLOCK=1024)
|
||||||
assert counter == 1
|
assert counter == 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||||
|
def test_specialize(mode):
|
||||||
|
counter = 0
|
||||||
|
def inc_counter(key, binary):
|
||||||
|
nonlocal counter
|
||||||
|
counter += 1
|
||||||
|
JITFunction.cache_hook = inc_counter
|
||||||
|
reset_tmp_dir()
|
||||||
|
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||||
|
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
|
||||||
|
target = {'enable': 5, 'disable': 1}[mode]
|
||||||
|
for i in [1, 2, 4, 8, 16, 32]:
|
||||||
|
function[(1,)](x, i, BLOCK=512)
|
||||||
|
assert counter == target
|
||||||
|
@@ -667,7 +667,7 @@ class Kernel:
|
|||||||
stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
|
stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
|
||||||
# stream = torch.cuda.current_stream(device).cuda_stream
|
# stream = torch.cuda.current_stream(device).cuda_stream
|
||||||
# make key for cache
|
# make key for cache
|
||||||
return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
|
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
|
||||||
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
|
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user