diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8b765a5c8..6a301fd1a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -794,7 +794,7 @@ def heuristics(values): .. highlight:: python .. code-block:: python - @heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size