Device function & PassManager

This commit is contained in:
Yan Da
2022-04-15 14:41:57 +08:00
parent 44d75cf9bb
commit 1c52bd587d
5 changed files with 464 additions and 92 deletions

View File

@@ -181,7 +181,10 @@ target_link_libraries(triton
TritonDriver TritonDriver
# TritonCodeGen # TritonCodeGen
MLIRCAPIIR # optimizations
MLIRPass
MLIRTransforms
${PYTHON_LIBRARIES} ${PYTHON_LIBRARIES}
) )

View File

@@ -8,6 +8,10 @@
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/ir/Dialect.h" #include "triton/ir/Dialect.h"
#include "triton/ir/Types.h" #include "triton/ir/Types.h"
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void { .def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
self->setAttr(name, attr); self->setAttr(name, attr);
}) })
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
return self->getNumResults();
})
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value { .def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx); return self->getResult(idx);
}) })
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp"); py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module") py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
.def("dump", [](mlir::ModuleOp &self) -> void { .def("dump", &mlir::ModuleOp::dump)
self.dump();
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp); self.push_back(funcOp);
}) })
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
; ;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function") py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* { .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock(); return self.addEntryBlock();
}, ret::reference) }, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
; ;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint"); py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc); return self.create<mlir::ModuleOp>(loc);
}) })
// control flow .def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
.def("ret_void", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc); self.create<mlir::ReturnOp>(loc, vals);
}, ret::reference) })
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::CallOp>(loc, func, args);
})
// insertion block/point // insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void { .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block); self.setInsertionPointToStart(&block);
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
} }
throw std::runtime_error("invalid function type"); throw std::runtime_error("invalid function type");
}) })
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* { .def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
mlir::Region *parent = self.getBlock()->getParent(); mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent); return self.createBlock(parent);
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) // .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
// .def("create_barrier", &ir::builder::create_barrier, ret::reference); // .def("create_barrier", &ir::builder::create_barrier, ret::reference);
; ;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
self.run(mod.getOperation());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})
;
} }
void init_triton(py::module &m) { void init_triton(py::module &m) {

View File

@@ -14,6 +14,7 @@ import textwrap
import time import time
import warnings import warnings
from typing import Dict, Optional, Set, Tuple, Union from typing import Dict, Optional, Set, Tuple, Union
from numpy import isin
import torch import torch
from filelock import FileLock from filelock import FileLock
@@ -22,6 +23,41 @@ import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from .tools.disasm import extract from .tools.disasm import extract
def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
return 'fp16'
if ty.is_bf16():
return 'bf16'
if ty.is_fp32():
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
assert False, "Unsupport type"
def mangle_fn(name, arg_tys, constants):
# doesn't mangle ret type, which must be a function of arg tys
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
mangled_constants = mangled_constants.replace('.', '_d_')
mangled_constants = mangled_constants.replace("'", '_sq_')
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
class enter_sub_region: class enter_sub_region:
def __init__(self, generator: CodeGenerator): def __init__(self, generator: CodeGenerator):
self.generator = generator self.generator = generator
@@ -40,15 +76,16 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs): def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context) self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
self.prototype = prototype self.prototype = prototype
self.gscope = gscope self.gscope = gscope
self.lscope = dict() self.lscope = dict()
self.attributes = attributes self.attributes = attributes
self.constants = constants self.constants = constants
self.kwargs = kwargs self.is_kernel = is_kernel
self.last_node = None self.last_node = None
self.builtins = { self.builtins = {
'range': range, 'range': range,
@@ -104,7 +141,7 @@ class CodeGenerator(ast.NodeVisitor):
# #
def visit_compound_statement(self, stmts): def visit_compound_statement(self, stmts):
for stmt in stmts: for stmt in stmts:
self.last_ret = self.visit(stmt) self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return): if isinstance(stmt, ast.Return):
break break
return stmts and isinstance(stmt, ast.Return) return stmts and isinstance(stmt, ast.Return)
@@ -120,12 +157,18 @@ class CodeGenerator(ast.NodeVisitor):
# By design, only non-kernel functions can return # By design, only non-kernel functions can return
def visit_Return(self, node): def visit_Return(self, node):
ret = self.visit(node.value) ret_value = self.visit(node.value)
if ret is None: if ret_value is None:
return self.builder.ret_void() return self.builder.ret([])
return ret if isinstance(ret_value, list):
assert False, "returing a tuple is not supported"
else:
ret = triton.language.core._to_tensor(ret_value, self.builder)
ret_type = ret.type
self.builder.ret([ret_value.handle])
return ret_type
def visit_FunctionDef(self, node, inline=False, arg_values=None): def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args) arg_names, kwarg_names = self.visit(node.args)
# initialize defaults # initialize defaults
for i, default_value in enumerate(node.args.defaults): for i, default_value in enumerate(node.args.defaults):
@@ -138,46 +181,51 @@ class CodeGenerator(ast.NodeVisitor):
else: else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node) self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function # initialize function
if inline: fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
for arg_name, arg_value in zip(arg_names, arg_values): fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
self.set_value(arg_name, arg_value) self.module.push_back(fn)
self.visit_compound_statement(node.body) entry = fn.add_entry_block()
return self.last_ret arg_values = []
else: idx = 0
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) for i, arg_name in enumerate(arg_names):
self.module.push_back(fn) if i in self.constants:
entry = fn.add_entry_block() cst = self.constants[i]
arg_values = [] if not isinstance(cst, triton.language.constexpr):
idx = 0 cst = triton.language.constexpr(self.constants[i])
for i, arg_name in enumerate(arg_names): arg_values.append(cst)
if i in self.constants: else:
cst = self.constants[i] pass
if not isinstance(cst, triton.language.constexpr): # TODO: ...
cst = triton.language.constexpr(self.constants[i]) # if i in self.attributes:
arg_values.append(cst) # is_ptr = fn.args[idx].type.is_ptr()
else: # attr = 'aligned' if is_ptr else 'multiple_of'
pass # attr = getattr(_triton.ir.attribute_kind, attr)
# TODO: ... # attr = _triton.ir.attribute(attr, self.attributes[i])
# if i in self.attributes: # fn.add_attr(idx + 1, attr)
# is_ptr = fn.args[idx].type.is_ptr() # fn.args[idx].name = arg_name
# attr = 'aligned' if is_ptr else 'multiple_of' arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
# attr = getattr(_triton.ir.attribute_kind, attr) idx += 1
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): insert_pt = self.builder.get_insertion_block()
self.set_value(arg_name, arg_value) for arg_name, arg_value in zip(arg_names, arg_values):
self.builder.set_insertion_point_to_start(entry) self.set_value(arg_name, arg_value)
# visit function body self.builder.set_insertion_point_to_start(entry)
self.visit_compound_statement(node.body) # visit function body
# finalize function has_ret = self.visit_compound_statement(node.body)
self.builder.ret_void() # finalize function
if not has_ret:
self.builder.ret([])
else:
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
def visit_arguments(self, node): def visit_arguments(self, node):
arg_names = [] arg_names = []
@@ -219,7 +267,6 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple): if not isinstance(values, tuple):
values = [values] values = [values]
for name, value in zip(names, values): for name, value in zip(names, values):
# TODO: can we store constexpr here to support constant folding?
# by default, constexpr are assigned into python variable # by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr): if isinstance(value, triton.language.constexpr):
value = value.value value = value.value
@@ -581,7 +628,40 @@ class CodeGenerator(ast.NodeVisitor):
kws.update(self.visit(keyword)) kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args] args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction): if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws) from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0:
return None
elif call_op.get_num_results() == 1:
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
raise RuntimeError("Not implemented")
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core: sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
@@ -1012,7 +1092,7 @@ class JITFunction:
cache_hook = None cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None): def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function # information of wrapped function
self.fn = fn self.fn = fn
self.module = fn.__module__ self.module = fn.__module__
@@ -1021,6 +1101,7 @@ class JITFunction:
self.arg_defaults = [v.default for v in signature.parameters.values()] self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version self.version = version
self.inline = inline
self.src = textwrap.dedent(inspect.getsource(fn)) self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):] self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
@@ -1035,6 +1116,8 @@ class JITFunction:
# annotations # annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__ self.__annotations__ = fn.__annotations__
# constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs # forward docs
self.__doc__ = fn.__doc__ self.__doc__ = fn.__doc__
self.__name__ = fn.__name__ self.__name__ = fn.__name__
@@ -1061,32 +1144,8 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef) assert isinstance(tree.body[0], ast.FunctionDef)
return tree return tree
# Called by CodeGenerator.visit_Call() def __call__(self, *args, **kwargs):
def __call__(self, *args, generator: CodeGenerator, **kwargs): raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
try:
from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in arg_values]
# Record values in the caller (parent scope)
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
# TODO: clear values other than args
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
return ret
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
# - when `.src` attribute is set, cache path needs # - when `.src` attribute is set, cache path needs
# to be reinitialized # to be reinitialized
@@ -1172,7 +1231,7 @@ class JITFunction:
# generate Triton-IR # generate Triton-IR
# export symbols visible from self into code-generator object # export symbols visible from self into code-generator object
gscope = self.__globals__ gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try: try:
generator.visit(self.parse()) generator.visit(self.parse())
except Exception as e: except Exception as e:
@@ -1232,7 +1291,7 @@ class JITFunction:
# generate Triton-IR # generate Triton-IR
# export symbols visible from self into code-generator object # export symbols visible from self into code-generator object
gscope = self.__globals__ gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try: try:
generator.visit(self.parse()) generator.visit(self.parse())
except Exception as e: except Exception as e:

261
rewrite-test/inline.mlir Normal file
View File

@@ -0,0 +1,261 @@
============================= test session starts ==============================
platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test
collected 6 items
scf_tests.py .....module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
%c8_i32 = arith.constant 8 : i32
%3 = arith.muli %2, %c8_i32 : i32
%4 = arith.divsi %0, %3 : i32
%c8_i32_0 = arith.constant 8 : i32
%5 = arith.muli %4, %c8_i32_0 : i32
%6 = arith.subi %1, %5 : i32
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
%8 = arith.remsi %0, %7 : i32
%9 = arith.addi %5, %8 : i32
%10 = arith.remsi %0, %3 : i32
%11 = arith.divsi %10, %7 : i32
%c64_i32 = arith.constant 64 : i32
%12 = arith.muli %9, %c64_i32 : i32
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
%15 = arith.addi %14, %13 : tensor<64xi32>
%c64_i32_1 = arith.constant 64 : i32
%16 = arith.muli %11, %c64_i32_1 : i32
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
%19 = arith.addi %18, %17 : tensor<64xi32>
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%23 = arith.muli %21, %22 : tensor<64x1xi32>
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32 = arith.constant 1 : i32
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%26 = arith.muli %24, %25 : tensor<1x32xi32>
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%29 = arith.addi %27, %28 : tensor<64x32xi32>
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%34 = arith.muli %32, %33 : tensor<32x1xi32>
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_2 = arith.constant 1 : i32
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
%37 = arith.muli %35, %36 : tensor<1x64xi32>
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%40 = arith.addi %38, %39 : tensor<32x64xi32>
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%44 = arith.index_cast %c0_i32 : i32 to index
%45 = arith.index_cast %arg5 : i32 to index
%46 = arith.index_cast %c32_i32 : i32 to index
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_10 = arith.constant 0.000000e+00 : f32
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
%c32_i32_11 = arith.constant 32 : i32
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
%c32_i32_12 = arith.constant 32 : i32
%84 = arith.muli %arg7, %c32_i32_12 : i32
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_3 = arith.constant 64 : i32
%49 = arith.muli %9, %c64_i32_3 : i32
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
%52 = arith.addi %51, %50 : tensor<64xi32>
%c64_i32_4 = arith.constant 64 : i32
%53 = arith.muli %11, %c64_i32_4 : i32
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
%56 = arith.addi %55, %54 : tensor<64xi32>
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%59 = arith.muli %58, %57 : tensor<64x1xi32>
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_5 = arith.constant 1 : i32
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
%64 = arith.muli %62, %63 : tensor<1x64xi32>
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%76 = arith.andi %74, %75 : tensor<64x64xi1>
tt.store %67, %48, %76, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%0 = arith.addi %arg0, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%1 = arith.subi %0, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%2 = arith.divsi %1, %c64_i32_0 : i32
return %2 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%c8_i32_0 = arith.constant 8 : i32
%1 = select %0, %arg0, %c8_i32_0 : i32
return %1 : i32
}
}
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%c8_i32 = arith.constant 8 : i32
%c63_i32 = arith.constant 63 : i32
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%cst_0 = arith.constant dense<true> : tensor<64x32xi1>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%cst_2 = arith.constant dense<true> : tensor<32x64xi1>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%c32_i32 = arith.constant 32 : i32
%c1_i32 = arith.constant 1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
%18 = arith.addi %17, %16 : tensor<64xi32>
%19 = arith.muli %14, %c64_i32 : i32
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
%22 = arith.addi %21, %20 : tensor<64xi32>
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%26 = arith.muli %24, %25 : tensor<64x1xi32>
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%29 = arith.muli %27, %28 : tensor<1x32xi32>
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%32 = arith.addi %30, %31 : tensor<64x32xi32>
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%37 = arith.muli %35, %36 : tensor<32x1xi32>
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%40 = arith.muli %38, %39 : tensor<1x64xi32>
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%43 = arith.addi %41, %42 : tensor<32x64xi32>
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
%46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%47 = arith.index_cast %arg5 : i32 to index
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
%85 = arith.muli %arg7, %c32_i32 : i32
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
%50 = arith.muli %12, %c64_i32 : i32
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
%53 = arith.addi %52, %51 : tensor<64xi32>
%54 = arith.muli %14, %c64_i32 : i32
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%60 = arith.muli %59, %58 : tensor<64x1xi32>
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%65 = arith.muli %63, %64 : tensor<1x64xi32>
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%77 = arith.andi %75, %76 : tensor<64x64xi1>
tt.store %68, %49, %77, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%0 = arith.addi %arg0, %c63_i32 : i32
%1 = arith.divsi %0, %c64_i32 : i32
return %1 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%1 = select %0, %arg0, %c8_i32 : i32
return %1 : i32
}
}
.
============================== 6 passed in 1.21s ===============================

View File

@@ -2,12 +2,14 @@ import pytest
import triton import triton
import triton.language as tl import triton.language as tl
import triton._C.libtriton.triton as _triton
import torch import torch
def test_if(): def test_if():
ref_ir = """module { ref_ir = """module {
func @only_if(%arg0: i32, %arg1: i32, %arg2: i32) { func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%cst = arith.constant -1.000000e+00 : f32 %cst = arith.constant -1.000000e+00 : f32
%0 = arith.cmpi sgt, %arg2, %arg0 : i32 %0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) { %1 = scf.if %0 -> (f32) {
@@ -36,7 +38,7 @@ def test_if():
def test_if_else(): def test_if_else():
ref_ir = """module { ref_ir = """module {
func @if_else(%arg0: i32, %arg1: i32, %arg2: i32) { func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%0 = arith.cmpi sgt, %arg2, %arg0 : i32 %0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) { %1 = scf.if %0 -> (f32) {
%cst = arith.constant 0.000000e+00 : f32 %cst = arith.constant 0.000000e+00 : f32
@@ -65,7 +67,7 @@ def test_if_else():
def test_for(): def test_for():
ref_ir = """module { ref_ir = """module {
func @for_loop(%arg0: i32) { func @for_loop__i32__(%arg0: i32) {
%cst = arith.constant 1.000000e+00 : f32 %cst = arith.constant 1.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32 %c1_i32 = arith.constant 1 : i32
@@ -95,7 +97,7 @@ def test_for():
def test_while(): def test_while():
ref_ir = """module { ref_ir = """module {
func @generic_while(%arg0: i32) { func @generic_while__i32__(%arg0: i32) {
%c-1_i32 = arith.constant -1 : i32 %c-1_i32 = arith.constant -1 : i32
%0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 { %0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 {
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
@@ -124,7 +126,7 @@ def test_while():
def test_nested(): def test_nested():
ref_ir = """module { ref_ir = """module {
func @nested_cf(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) { func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) {
%cst = arith.constant 0.000000e+00 : f32 %cst = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32 %c1_i32 = arith.constant 1 : i32
@@ -402,4 +404,14 @@ def test_matmul():
) )
verify = mod.verify() verify = mod.verify()
assert verify assert verify
assert ref_ir == mod.str() # assert ref_ir == mod.str()
print(mod.str())
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())