diff --git a/python/src/triton.cc b/python/src/triton.cc
index fcebeeb5f..260f83942 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -236,8 +236,14 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       continue;
     }
     // argument is `constexpr`
-    if(py::hasattr(arg, "value")){
+    if (py::hasattr(arg, "value")) {
       py::object value = arg.attr("value");
+      // check if value is a callable object using PyCallable_Check
+      if (PyCallable_Check(value.ptr())) {
+        throw std::runtime_error(
+            "constant argument cannot be a callable object: " +
+            std::string(py::str(arg)));
+      }
       py::object name = arg_names[i];
       constants[name] = value;
       py::object repr = py::repr(value);
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index d866d6983..fd95dbd38 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -130,3 +130,23 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
     cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
     spec_type = None if cache_str_match is None else cache_str_match.group(1)
     assert spec_type == value_type
+
+
+def test_constexpr_not_callable() -> None:
+    @triton.jit
+    def kernel(X, c: tl.constexpr):
+        tl.store(X, 2)
+
+    x = torch.empty(1, dtype=torch.int32, device='cuda')
+    error = False
+    try:
+        kernel[(1, )](x, c="str")
+    except BaseException:
+        error = True
+    assert error is False
+    # try and catch
+    try:
+        kernel[(1, )](x, c=tl.abs)
+    except BaseException:
+        error = True
+    assert error is True
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 39bf8c46a..231d3371c 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -236,8 +236,8 @@ def matmul_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    if ACTIVATION:
-        accumulator = ACTIVATION(accumulator)
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
     c = accumulator.to(tl.float16)
 
     # -----------------------------------------------------------
@@ -261,7 +261,7 @@ def leaky_relu(x):
 
 
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
-def matmul(a, b, activation=None):
+def matmul(a, b, activation=""):
     # checks constraints
     assert a.shape[1] == b.shape[0], "incompatible dimensions"
     assert a.is_contiguous(), "matrix A must be contiguous"
@@ -347,7 +347,7 @@ def benchmark(M, N, K, provider):
         )
     if provider == 'triton + relu':
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: matmul(a, b, activation=leaky_relu)
+            lambda: matmul(a, b, activation="leaky_relu")
         )
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
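
Note (illustration, not part of the patch): a minimal sketch of the user-facing behavior the PyCallable_Check guard enforces, assuming a CUDA device and a Triton build containing the change above. Passing a plain constant (e.g. a string) as a tl.constexpr argument still works, while passing a callable such as tl.abs is rejected at launch time; the new test catches BaseException, so the exact exception type seen from Python is not guaranteed here.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def kernel(X, c: tl.constexpr):
        tl.store(X, 2)


    x = torch.empty(1, dtype=torch.int32, device='cuda')

    # A plain constant (here a string) is still accepted as a constexpr value.
    kernel[(1, )](x, c="str")

    # A callable such as tl.abs is now rejected when the launcher parses arguments.
    try:
        kernel[(1, )](x, c=tl.abs)
    except Exception as e:
        print(e)  # expected to mention: constant argument cannot be a callable object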