[FRONTEND] Reduce number of compiles in JITFunction (#704)
I suspect this was the cause of the "new compiles even on a warm cache" behavior I was seeing, though haven't 100% confirmed it. Python `set()` iteration order is nondeterministic when you create a new process. So the same args could produce different `instance_descriptor`s and have false cache misses.
This commit is contained in:
@@ -151,8 +151,8 @@ class JITFunction(KernelInterface):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
divisible_by_16 = [i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize]
|
||||
equal_to_1 = [i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize]
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
|
||||
|
||||
|
Reference in New Issue
Block a user