[FRONTEND] Reduce number of compiles in JITFunction (#704)

I suspect this was the cause of the "new compiles even on a warm cache"
behavior I was seeing, though haven't 100% confirmed it.

Python `set()` iteration order is nondeterministic when you create a new
process. So the same args could produce different `instance_descriptor`s
and have false cache misses.
This commit is contained in:
Jason Ansel
2022-09-23 14:44:52 -07:00
committed by GitHub
parent 25e1b36785
commit 579c03615d

View File

@@ -151,8 +151,8 @@ class JITFunction(KernelInterface):
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
divisible_by_16 = [i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize]
equal_to_1 = [i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize]
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)