[FRONTEND] Refresh cache when the source code of outlined functions is changed (#590)
@@ -236,8 +236,14 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       continue;
     }
     // argument is `constexpr`
-    if(py::hasattr(arg, "value")){
+    if (py::hasattr(arg, "value")) {
       py::object value = arg.attr("value");
+      // check if value is a callable object using PyCallable_Check
+      if (PyCallable_Check(value.ptr())) {
+        throw std::runtime_error(
+            "constant argument cannot be a callable object: " +
+            std::string(py::str(arg)));
+      }
       py::object name = arg_names[i];
       constants[name] = value;
       py::object repr = py::repr(value);
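At the Python level, the check above means a callable can no longer be passed as the value of a `tl.constexpr` kernel parameter; doing so now raises a RuntimeError carrying the message added in this hunk (pybind11 translates std::runtime_error into RuntimeError). A minimal sketch of the user-facing behavior; the kernel here is illustrative and mirrors the regression test added in the next hunk, which exercises the same path with tl.abs:

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(X, c: tl.constexpr):
    tl.store(X, 2)

x = torch.empty(1, dtype=torch.int32, device='cuda')
kernel[(1, )](x, c="str")   # ok: non-callable constexpr values are still accepted
kernel[(1, )](x, c=tl.abs)  # raises: "constant argument cannot be a callable object: ..."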
@@ -130,3 +130,23 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> None
     cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
     spec_type = None if cache_str_match is None else cache_str_match.group(1)
     assert spec_type == value_type
+
+
+def test_constexpr_not_callable() -> None:
+    @triton.jit
+    def kernel(X, c: tl.constexpr):
+        tl.store(X, 2)
+
+    x = torch.empty(1, dtype=torch.int32, device='cuda')
+    error = False
+    try:
+        kernel[(1, )](x, c="str")
+    except BaseException:
+        error = True
+    assert error is False
+    # try and catch
+    try:
+        kernel[(1, )](x, c=tl.abs)
+    except BaseException:
+        error = True
+    assert error is True
@@ -236,8 +236,8 @@ def matmul_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    if ACTIVATION:
-        accumulator = ACTIVATION(accumulator)
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
     c = accumulator.to(tl.float16)

 # -----------------------------------------------------------
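Passing the activation as a string and dispatching on it inside the kernel keeps the meta-parameter a plain, non-callable value, consistent with the constexpr check above; assuming ACTIVATION is declared as tl.constexpr (as in this tutorial), the string comparison is resolved when the kernel is specialized, so the branch adds no run-time cost. The leaky_relu being called is the tutorial's own @triton.jit helper, roughly along these lines (a sketch; the exact body in the tutorial may differ):

@triton.jit
def leaky_relu(x):
    # elementwise leaky ReLU applied to the FP32 accumulator tile
    return tl.where(x >= 0, x, 0.01 * x)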
@@ -261,7 +261,7 @@ def leaky_relu(x):
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel


-def matmul(a, b, activation=None):
+def matmul(a, b, activation=""):
     # checks constraints
     assert a.shape[1] == b.shape[0], "incompatible dimensions"
     assert a.is_contiguous(), "matrix A must be contiguous"
@@ -347,7 +347,7 @@ def benchmark(M, N, K, provider):
         )
     if provider == 'triton + relu':
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: matmul(a, b, activation=leaky_relu)
+            lambda: matmul(a, b, activation="leaky_relu")
         )
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)