Fix OpBuilder
This commit is contained in:
@@ -1177,7 +1177,16 @@ class JITFunction:
|
||||
# Compile to ttir, for the propose of testing MLIR rewriting
|
||||
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
|
||||
# TODO: share code with _compile & __call__
|
||||
|
||||
# handle arguments passed by name
|
||||
kwargs = {self.arg_names.index(name): value for name, value in kwargs.items()}
|
||||
wargs = list(wargs)
|
||||
for i, pos in enumerate(sorted(kwargs)):
|
||||
wargs.insert(pos + i, kwargs[pos])
|
||||
if len(wargs) != len(self.arg_names):
|
||||
raise TypeError(f"Function takes {len(self.arg_names)} positional arguments but {len(wargs)} were given")
|
||||
# handle annotations
|
||||
for pos, _type in self.annotations.items():
|
||||
wargs[pos] = _type(wargs[pos])
|
||||
# preparing args
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
# attributes
|
||||
@@ -1191,7 +1200,7 @@ class JITFunction:
|
||||
attributes[i] = min(Kernel.pow2_divisor(addr),
|
||||
Kernel.pow2_divisor(range_size))
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
||||
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
|
||||
|
Reference in New Issue
Block a user