diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cee0c220..0b80890cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,10 @@ target_link_libraries(triton TritonDriver # TritonCodeGen - MLIRCAPIIR + # optimizations + MLIRPass + MLIRTransforms + ${PYTHON_LIBRARIES} ) diff --git a/python/src/triton.cc b/python/src/triton.cc index ee7b4877a..fe56340e4 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -8,6 +8,10 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + + #include "triton/ir/Dialect.h" #include "triton/ir/Types.h" @@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) { .def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void { self->setAttr(name, attr); }) + .def("get_num_results", [](mlir::OpState &self) -> unsigned { + return self->getNumResults(); + }) .def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value { return self->getResult(idx); }) @@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) { py::class_(m, "CondtionOp"); py::class_(m, "module") - .def("dump", [](mlir::ModuleOp &self) -> void { - self.dump(); - }) + .def("dump", &mlir::ModuleOp::dump) .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { self.push_back(funcOp); }) + .def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { + return self.lookupSymbol(funcName); + }) ; py::class_(m, "function") @@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) { .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* { return self.addEntryBlock(); }, ret::reference) + .def("reset_type", &mlir::FuncOp::setType) ; py::class_(m, "InsertPoint"); @@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc); }) - // control flow - .def("ret_void", [](mlir::OpBuilder &self) { + .def("ret", [](mlir::OpBuilder &self, std::vector &vals) -> void { auto loc = self.getUnknownLoc(); - self.create(loc); - }, ret::reference) + self.create(loc, vals); + }) + .def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector &args) -> mlir::OpState { + auto loc = self.getUnknownLoc(); + return self.create(loc, func, args); + }) // insertion block/point .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void { self.setInsertionPointToStart(&block); @@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) { } throw std::runtime_error("invalid function type"); }) + .def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module, + std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp { + if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + auto loc = self.getUnknownLoc(); + if (auto funcTy = funcType.dyn_cast()) { + return self.create(loc, funcName, funcTy); + } + throw std::runtime_error("invalid function type"); + }) .def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* { mlir::Region *parent = self.getBlock()->getParent(); return self.createBlock(parent); @@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) { // .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) // .def("create_barrier", &ir::builder::create_barrier, ret::reference); ; + + py::class_(m, "pass_manager") + .def(py::init()) + .def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) { + self.run(mod.getOperation()); + }) + .def("add_inliner_pass", [](mlir::PassManager &self) { + self.addPass(mlir::createInlinerPass()); + }) + ; } void init_triton(py::module &m) { diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 13821f8ca..59e26a3cc 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -14,6 +14,7 @@ import textwrap import time import warnings from typing import Dict, Optional, Set, Tuple, Union +from numpy import isin import torch from filelock import FileLock @@ -22,6 +23,41 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + return 'i' + str(ty.int_bitwidth) + if ty.is_fp8(): + return 'fp8' + if ty.is_fp16(): + return 'fp16' + if ty.is_bf16(): + return 'bf16' + if ty.is_fp32(): + return 'fp32' + if ty.is_fp64(): + return 'fp64' + if ty.is_void(): + return 'V' + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + assert False, "Unsupport type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + class enter_sub_region: def __init__(self, generator: CodeGenerator): self.generator = generator @@ -40,15 +76,16 @@ class enter_sub_region: self.generator.local_defs = self.prev_defs class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, kwargs): + def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): self.builder = _triton.ir.builder(context) - self.module = self.builder.create_module() + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = function_types self.prototype = prototype self.gscope = gscope self.lscope = dict() self.attributes = attributes self.constants = constants - self.kwargs = kwargs + self.is_kernel = is_kernel self.last_node = None self.builtins = { 'range': range, @@ -104,7 +141,7 @@ class CodeGenerator(ast.NodeVisitor): # def visit_compound_statement(self, stmts): for stmt in stmts: - self.last_ret = self.visit(stmt) + self.last_ret_type = self.visit(stmt) if isinstance(stmt, ast.Return): break return stmts and isinstance(stmt, ast.Return) @@ -120,12 +157,18 @@ class CodeGenerator(ast.NodeVisitor): # By design, only non-kernel functions can return def visit_Return(self, node): - ret = self.visit(node.value) - if ret is None: - return self.builder.ret_void() - return ret + ret_value = self.visit(node.value) + if ret_value is None: + return self.builder.ret([]) + if isinstance(ret_value, list): + assert False, "returing a tuple is not supported" + else: + ret = triton.language.core._to_tensor(ret_value, self.builder) + ret_type = ret.type + self.builder.ret([ret_value.handle]) + return ret_type - def visit_FunctionDef(self, node, inline=False, arg_values=None): + def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): @@ -138,46 +181,51 @@ class CodeGenerator(ast.NodeVisitor): else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) - # store keyword arguments in local scope - self.lscope[kwarg_names] = self.kwargs # initialize function - if inline: - for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) - self.visit_compound_statement(node.body) - return self.last_ret - else: - fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) - self.module.push_back(fn) - entry = fn.add_entry_block() - arg_values = [] - idx = 0 - for i, arg_name in enumerate(arg_names): - if i in self.constants: - cst = self.constants[i] - if not isinstance(cst, triton.language.constexpr): - cst = triton.language.constexpr(self.constants[i]) - arg_values.append(cst) - else: - pass - # TODO: ... - # if i in self.attributes: - # is_ptr = fn.args[idx].type.is_ptr() - # attr = 'aligned' if is_ptr else 'multiple_of' - # attr = getattr(_triton.ir.attribute_kind, attr) - # attr = _triton.ir.attribute(attr, self.attributes[i]) - # fn.add_attr(idx + 1, attr) - # fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) - idx += 1 + fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) + fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder)) + self.module.push_back(fn) + entry = fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) + arg_values.append(cst) + else: + pass + # TODO: ... + # if i in self.attributes: + # is_ptr = fn.args[idx].type.is_ptr() + # attr = 'aligned' if is_ptr else 'multiple_of' + # attr = getattr(_triton.ir.attribute_kind, attr) + # attr = _triton.ir.attribute(attr, self.attributes[i]) + # fn.add_attr(idx + 1, attr) + # fn.args[idx].name = arg_name + arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) + idx += 1 - for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) - self.builder.set_insertion_point_to_start(entry) - # visit function body - self.visit_compound_statement(node.body) - # finalize function - self.builder.ret_void() + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + has_ret = self.visit_compound_statement(node.body) + # finalize function + if not has_ret: + self.builder.ret([]) + else: + # update return type + if isinstance(self.last_ret_type, tuple): + self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type] + fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.last_ret_type] + fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) def visit_arguments(self, node): arg_names = [] @@ -219,7 +267,6 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): - # TODO: can we store constexpr here to support constant folding? # by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value @@ -581,7 +628,40 @@ class CodeGenerator(ast.NodeVisitor): kws.update(self.visit(keyword)) args = [self.visit(arg) for arg in node.args] if isinstance(fn, JITFunction): - return fn(*args, generator=self, **kws) + from inspect import getcallargs + args = getcallargs(fn.fn, *args, **kws) + args = [args[name] for name in fn.arg_names] + args = [arg if isinstance(arg, triton.language.tensor) + else triton.language.constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + ret_type = triton.language.void + prototype = triton.language.function_type([ret_type], arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types) + generator.visit(fn.parse()) + callee_ret_type = generator.last_ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0: + return None + elif call_op.get_num_results() == 1: + return triton.language.tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + raise RuntimeError("Not implemented") if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) @@ -1012,7 +1092,7 @@ class JITFunction: cache_hook = None - def __init__(self, fn, version=None, do_not_specialize=None): + def __init__(self, fn, version=None, inline=True, do_not_specialize=None): # information of wrapped function self.fn = fn self.module = fn.__module__ @@ -1021,6 +1101,7 @@ class JITFunction: self.arg_defaults = [v.default for v in signature.parameters.values()] self.version = version + self.inline = inline self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize @@ -1035,6 +1116,8 @@ class JITFunction: # annotations self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.__annotations__ = fn.__annotations__ + # constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] # forward docs self.__doc__ = fn.__doc__ self.__name__ = fn.__name__ @@ -1061,32 +1144,8 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - # Called by CodeGenerator.visit_Call() - def __call__(self, *args, generator: CodeGenerator, **kwargs): - try: - from inspect import getcallargs - arg_values = getcallargs(self.fn, *args, **kwargs) - arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.tensor) - else triton.language.constexpr(arg) for arg in arg_values] - - # Record values in the caller (parent scope) - gscope = generator.gscope.copy() - lscope = generator.lscope.copy() - - # TODO: clear values other than args - generator.gscope = sys.modules[self.fn.__module__].__dict__ - generator.lscope = dict() - ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) - generator.gscope = gscope - generator.lscope = lscope - - return ret - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") # - when `.src` attribute is set, cache path needs # to be reinitialized @@ -1172,7 +1231,7 @@ class JITFunction: # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) try: generator.visit(self.parse()) except Exception as e: @@ -1232,7 +1291,7 @@ class JITFunction: # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) try: generator.visit(self.parse()) except Exception as e: diff --git a/rewrite-test/inline.mlir b/rewrite-test/inline.mlir new file mode 100644 index 000000000..e1bd07e8f --- /dev/null +++ b/rewrite-test/inline.mlir @@ -0,0 +1,261 @@ +============================= test session starts ============================== +platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0 +rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test +collected 6 items + +scf_tests.py .....module { + func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32 + %2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32 + %c8_i32 = arith.constant 8 : i32 + %3 = arith.muli %2, %c8_i32 : i32 + %4 = arith.divsi %0, %3 : i32 + %c8_i32_0 = arith.constant 8 : i32 + %5 = arith.muli %4, %c8_i32_0 : i32 + %6 = arith.subi %1, %5 : i32 + %7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32 + %8 = arith.remsi %0, %7 : i32 + %9 = arith.addi %5, %8 : i32 + %10 = arith.remsi %0, %3 : i32 + %11 = arith.divsi %10, %7 : i32 + %c64_i32 = arith.constant 64 : i32 + %12 = arith.muli %9, %c64_i32 : i32 + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %14 = tt.broadcast %12 : (i32) -> tensor<64xi32> + %15 = arith.addi %14, %13 : tensor<64xi32> + %c64_i32_1 = arith.constant 64 : i32 + %16 = arith.muli %11, %c64_i32_1 : i32 + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %18 = tt.broadcast %16 : (i32) -> tensor<64xi32> + %19 = arith.addi %18, %17 : tensor<64xi32> + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32> + %22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32> + %23 = arith.muli %21, %22 : tensor<64x1xi32> + %24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32> + %c1_i32 = arith.constant 1 : i32 + %25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32> + %26 = arith.muli %24, %25 : tensor<1x32xi32> + %27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32> + %28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32> + %29 = arith.addi %27, %28 : tensor<64x32xi32> + %30 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr> + %31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr> + %32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32> + %33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32> + %34 = arith.muli %32, %33 : tensor<32x1xi32> + %35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32> + %c1_i32_2 = arith.constant 1 : i32 + %36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32> + %37 = arith.muli %35, %36 : tensor<1x64xi32> + %38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32> + %39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32> + %40 = arith.addi %38, %39 : tensor<32x64xi32> + %41 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<32x64x!tt.ptr> + %42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr> + %cst = arith.constant 0.000000e+00 : f32 + %43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %44 = arith.index_cast %c0_i32 : i32 to index + %45 = arith.index_cast %arg5 : i32 to index + %46 = arith.index_cast %c32_i32 : i32 to index + %47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr>) { + %cst_6 = arith.constant dense : tensor<64x32xi1> + %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> + %77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> + %cst_8 = arith.constant dense : tensor<32x64xi1> + %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> + %78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> + %cst_10 = arith.constant 0.000000e+00 : f32 + %79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32> + %80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> + %81 = arith.addf %arg10, %80 : tensor<64x64xf32> + %c32_i32_11 = arith.constant 32 : i32 + %82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32> + %83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr> + %c32_i32_12 = arith.constant 32 : i32 + %84 = arith.muli %arg7, %c32_i32_12 : i32 + %85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32> + %86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr> + scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr> + } + %48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16> + %c64_i32_3 = arith.constant 64 : i32 + %49 = arith.muli %9, %c64_i32_3 : i32 + %50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %51 = tt.broadcast %49 : (i32) -> tensor<64xi32> + %52 = arith.addi %51, %50 : tensor<64xi32> + %c64_i32_4 = arith.constant 64 : i32 + %53 = arith.muli %11, %c64_i32_4 : i32 + %54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %55 = tt.broadcast %53 : (i32) -> tensor<64xi32> + %56 = arith.addi %55, %54 : tensor<64xi32> + %57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32> + %58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32> + %59 = arith.muli %58, %57 : tensor<64x1xi32> + %60 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr> + %61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr> + %62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32> + %c1_i32_5 = arith.constant 1 : i32 + %63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32> + %64 = arith.muli %62, %63 : tensor<1x64xi32> + %65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr>) -> tensor<64x64x!tt.ptr> + %66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr> + %68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32> + %69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32> + %70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32> + %71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32> + %72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32> + %73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32> + %74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1> + %75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1> + %76 = arith.andi %74, %75 : tensor<64x64xi1> + tt.store %67, %48, %76, : tensor<64x64xf16> + return + } + func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 { + %c64_i32 = arith.constant 64 : i32 + %0 = arith.addi %arg0, %c64_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %1 = arith.subi %0, %c1_i32 : i32 + %c64_i32_0 = arith.constant 64 : i32 + %2 = arith.divsi %1, %c64_i32_0 : i32 + return %2 : i32 + } + func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { + %c8_i32 = arith.constant 8 : i32 + %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 + %c8_i32_0 = arith.constant 8 : i32 + %1 = select %0, %arg0, %c8_i32_0 : i32 + return %1 : i32 + } +} + +module { + func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %c8_i32 = arith.constant 8 : i32 + %c63_i32 = arith.constant 63 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %cst_0 = arith.constant dense : tensor<64x32xi1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> + %cst_2 = arith.constant dense : tensor<32x64xi1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c63_i32 : i32 + %4 = arith.divsi %3, %c64_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.cmpi slt, %8, %c8_i32 : i32 + %10 = select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c64_i32 : i32 + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %17 = tt.broadcast %15 : (i32) -> tensor<64xi32> + %18 = arith.addi %17, %16 : tensor<64xi32> + %19 = arith.muli %14, %c64_i32 : i32 + %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %21 = tt.broadcast %19 : (i32) -> tensor<64xi32> + %22 = arith.addi %21, %20 : tensor<64xi32> + %23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32> + %25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32> + %26 = arith.muli %24, %25 : tensor<64x1xi32> + %27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32> + %28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32> + %29 = arith.muli %27, %28 : tensor<1x32xi32> + %30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32> + %31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32> + %32 = arith.addi %30, %31 : tensor<64x32xi32> + %33 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr> + %34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr> + %35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32> + %36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32> + %37 = arith.muli %35, %36 : tensor<32x1xi32> + %38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32> + %39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32> + %40 = arith.muli %38, %39 : tensor<1x64xi32> + %41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32> + %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32> + %43 = arith.addi %41, %42 : tensor<32x64xi32> + %44 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<32x64x!tt.ptr> + %45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr> + %46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> + %47 = arith.index_cast %arg5 : i32 to index + %48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr>) { + %78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> + %79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> + %80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> + %81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> + %82 = arith.addf %arg10, %81 : tensor<64x64xf32> + %83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32> + %84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr> + %85 = arith.muli %arg7, %c32_i32 : i32 + %86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32> + %87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr> + scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr> + } + %49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16> + %50 = arith.muli %12, %c64_i32 : i32 + %51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %52 = tt.broadcast %50 : (i32) -> tensor<64xi32> + %53 = arith.addi %52, %51 : tensor<64xi32> + %54 = arith.muli %14, %c64_i32 : i32 + %55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %56 = tt.broadcast %54 : (i32) -> tensor<64xi32> + %57 = arith.addi %56, %55 : tensor<64xi32> + %58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32> + %59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32> + %60 = arith.muli %59, %58 : tensor<64x1xi32> + %61 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr> + %62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr> + %63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32> + %64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32> + %65 = arith.muli %63, %64 : tensor<1x64xi32> + %66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr>) -> tensor<64x64x!tt.ptr> + %67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr> + %69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32> + %70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32> + %71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32> + %72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32> + %73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32> + %74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32> + %75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1> + %76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1> + %77 = arith.andi %75, %76 : tensor<64x64xi1> + tt.store %68, %49, %77, : tensor<64x64xf16> + return + } + func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 { + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %0 = arith.addi %arg0, %c63_i32 : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + return %1 : i32 + } + func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { + %c8_i32 = arith.constant 8 : i32 + %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 + %1 = select %0, %arg0, %c8_i32 : i32 + return %1 : i32 + } +} + +. + +============================== 6 passed in 1.21s =============================== diff --git a/rewrite-test/scf_tests.py b/rewrite-test/scf_tests.py index 26fdf82cf..c4bb85840 100644 --- a/rewrite-test/scf_tests.py +++ b/rewrite-test/scf_tests.py @@ -2,12 +2,14 @@ import pytest import triton import triton.language as tl +import triton._C.libtriton.triton as _triton + import torch def test_if(): ref_ir = """module { - func @only_if(%arg0: i32, %arg1: i32, %arg2: i32) { + func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) { %cst = arith.constant -1.000000e+00 : f32 %0 = arith.cmpi sgt, %arg2, %arg0 : i32 %1 = scf.if %0 -> (f32) { @@ -36,7 +38,7 @@ def test_if(): def test_if_else(): ref_ir = """module { - func @if_else(%arg0: i32, %arg1: i32, %arg2: i32) { + func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) { %0 = arith.cmpi sgt, %arg2, %arg0 : i32 %1 = scf.if %0 -> (f32) { %cst = arith.constant 0.000000e+00 : f32 @@ -65,7 +67,7 @@ def test_if_else(): def test_for(): ref_ir = """module { - func @for_loop(%arg0: i32) { + func @for_loop__i32__(%arg0: i32) { %cst = arith.constant 1.000000e+00 : f32 %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -95,7 +97,7 @@ def test_for(): def test_while(): ref_ir = """module { - func @generic_while(%arg0: i32) { + func @generic_while__i32__(%arg0: i32) { %c-1_i32 = arith.constant -1 : i32 %0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 { %c0_i32 = arith.constant 0 : i32 @@ -124,7 +126,7 @@ def test_while(): def test_nested(): ref_ir = """module { - func @nested_cf(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) { + func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) { %cst = arith.constant 0.000000e+00 : f32 %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -402,4 +404,14 @@ def test_matmul(): ) verify = mod.verify() assert verify - assert ref_ir == mod.str() + # assert ref_ir == mod.str() + print(mod.str()) + + pm = _triton.ir.pass_manager(ctx) + pm.add_inliner_pass() + pm.run(mod) + + verify = mod.verify() + assert verify + # assert ref_ir == mod.str() + print(mod.str())