From 579c03615d13fbf81cfb1b5d67c7e1046a6bfd9a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 23 Sep 2022 14:44:52 -0700 Subject: [PATCH] [FRONTEND] Reduce number of compiles in JITFunction (#704) I suspect this was the cause of the "new compiles even on a warm cache" behavior I was seeing, though I haven't 100% confirmed it. Python `set()` iteration order is unspecified and is not guaranteed to be the same from one process to the next, so the same args could produce different `instance_descriptor`s and cause false cache misses. Building lists from `enumerate(args)` instead always yields the indices in ascending order, which makes the descriptor deterministic. --- python/triton/runtime/jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 5e1bc544b..4e9e58b06 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -151,8 +151,8 @@ class JITFunction(KernelInterface): if x is None: return True return False - divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} - equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + divisible_by_16 = [i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize] + equal_to_1 = [i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize] return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)