Device function & PassManager

This commit is contained in:
Yan Da
2022-04-15 14:41:57 +08:00
parent 44d75cf9bb
commit 1c52bd587d
5 changed files with 464 additions and 92 deletions

View File

@@ -181,7 +181,10 @@ target_link_libraries(triton
TritonDriver
# TritonCodeGen
MLIRCAPIIR
# optimizations
MLIRPass
MLIRTransforms
${PYTHON_LIBRARIES}
)

View File

@@ -8,6 +8,10 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
self->setAttr(name, attr);
})
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
return self->getNumResults();
})
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx);
})
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("dump", &mlir::ModuleOp::dump)
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
})
// control flow
.def("ret_void", [](mlir::OpBuilder &self) {
.def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc);
}, ret::reference)
self.create<mlir::ReturnOp>(loc, vals);
})
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::CallOp>(loc, func, args);
})
// insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block);
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
}
throw std::runtime_error("invalid function type");
})
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
// .def("create_barrier", &ir::builder::create_barrier, ret::reference);
;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
self.run(mod.getOperation());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})
;
}
void init_triton(py::module &m) {

View File

@@ -14,6 +14,7 @@ import textwrap
import time
import warnings
from typing import Dict, Optional, Set, Tuple, Union
from numpy import isin
import torch
from filelock import FileLock
@@ -22,6 +23,41 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
return 'fp16'
if ty.is_bf16():
return 'bf16'
if ty.is_fp32():
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
assert False, "Unsupport type"
def mangle_fn(name, arg_tys, constants):
# doesn't mangle ret type, which must be a function of arg tys
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
mangled_constants = mangled_constants.replace('.', '_d_')
mangled_constants = mangled_constants.replace("'", '_sq_')
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
@@ -40,15 +76,16 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module()
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.kwargs = kwargs
self.is_kernel = is_kernel
self.last_node = None
self.builtins = {
'range': range,
@@ -104,7 +141,7 @@ class CodeGenerator(ast.NodeVisitor):
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret = self.visit(stmt)
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
@@ -120,12 +157,18 @@ class CodeGenerator(ast.NodeVisitor):
# By design, only non-kernel functions can return
def visit_Return(self, node):
ret = self.visit(node.value)
if ret is None:
return self.builder.ret_void()
return ret
ret_value = self.visit(node.value)
if ret_value is None:
return self.builder.ret([])
if isinstance(ret_value, list):
assert False, "returing a tuple is not supported"
else:
ret = triton.language.core._to_tensor(ret_value, self.builder)
ret_type = ret.type
self.builder.ret([ret_value.handle])
return ret_type
def visit_FunctionDef(self, node, inline=False, arg_values=None):
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
@@ -138,46 +181,51 @@ class CodeGenerator(ast.NodeVisitor):
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.visit_compound_statement(node.body)
return self.last_ret
else:
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
self.visit_compound_statement(node.body)
# finalize function
self.builder.ret_void()
insert_pt = self.builder.get_insertion_block()
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
self.builder.ret([])
else:
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
def visit_arguments(self, node):
arg_names = []
@@ -219,7 +267,6 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
# TODO: can we store constexpr here to support constant folding?
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
@@ -581,7 +628,40 @@ class CodeGenerator(ast.NodeVisitor):
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0:
return None
elif call_op.get_num_results() == 1:
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
raise RuntimeError("Not implemented")
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
@@ -1012,7 +1092,7 @@ class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None):
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function
self.fn = fn
self.module = fn.__module__
@@ -1021,6 +1101,7 @@ class JITFunction:
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.inline = inline
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
@@ -1035,6 +1116,8 @@ class JITFunction:
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -1061,32 +1144,8 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
# Called by CodeGenerator.visit_Call()
def __call__(self, *args, generator: CodeGenerator, **kwargs):
try:
from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in arg_values]
# Record values in the caller (parent scope)
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
# TODO: clear values other than args
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
return ret
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
# - when `.src` attribute is set, cache path needs
# to be reinitialized
@@ -1172,7 +1231,7 @@ class JITFunction:
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try:
generator.visit(self.parse())
except Exception as e:
@@ -1232,7 +1291,7 @@ class JITFunction:
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try:
generator.visit(self.parse())
except Exception as e:

261
rewrite-test/inline.mlir Normal file
View File

@@ -0,0 +1,261 @@
============================= test session starts ==============================
platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test
collected 6 items
scf_tests.py .....module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
%c8_i32 = arith.constant 8 : i32
%3 = arith.muli %2, %c8_i32 : i32
%4 = arith.divsi %0, %3 : i32
%c8_i32_0 = arith.constant 8 : i32
%5 = arith.muli %4, %c8_i32_0 : i32
%6 = arith.subi %1, %5 : i32
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
%8 = arith.remsi %0, %7 : i32
%9 = arith.addi %5, %8 : i32
%10 = arith.remsi %0, %3 : i32
%11 = arith.divsi %10, %7 : i32
%c64_i32 = arith.constant 64 : i32
%12 = arith.muli %9, %c64_i32 : i32
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
%15 = arith.addi %14, %13 : tensor<64xi32>
%c64_i32_1 = arith.constant 64 : i32
%16 = arith.muli %11, %c64_i32_1 : i32
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
%19 = arith.addi %18, %17 : tensor<64xi32>
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%23 = arith.muli %21, %22 : tensor<64x1xi32>
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32 = arith.constant 1 : i32
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%26 = arith.muli %24, %25 : tensor<1x32xi32>
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%29 = arith.addi %27, %28 : tensor<64x32xi32>
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%34 = arith.muli %32, %33 : tensor<32x1xi32>
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_2 = arith.constant 1 : i32
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
%37 = arith.muli %35, %36 : tensor<1x64xi32>
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%40 = arith.addi %38, %39 : tensor<32x64xi32>
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%44 = arith.index_cast %c0_i32 : i32 to index
%45 = arith.index_cast %arg5 : i32 to index
%46 = arith.index_cast %c32_i32 : i32 to index
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_10 = arith.constant 0.000000e+00 : f32
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
%c32_i32_11 = arith.constant 32 : i32
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
%c32_i32_12 = arith.constant 32 : i32
%84 = arith.muli %arg7, %c32_i32_12 : i32
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_3 = arith.constant 64 : i32
%49 = arith.muli %9, %c64_i32_3 : i32
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
%52 = arith.addi %51, %50 : tensor<64xi32>
%c64_i32_4 = arith.constant 64 : i32
%53 = arith.muli %11, %c64_i32_4 : i32
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
%56 = arith.addi %55, %54 : tensor<64xi32>
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%59 = arith.muli %58, %57 : tensor<64x1xi32>
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_5 = arith.constant 1 : i32
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
%64 = arith.muli %62, %63 : tensor<1x64xi32>
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%76 = arith.andi %74, %75 : tensor<64x64xi1>
tt.store %67, %48, %76, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%0 = arith.addi %arg0, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%1 = arith.subi %0, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%2 = arith.divsi %1, %c64_i32_0 : i32
return %2 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%c8_i32_0 = arith.constant 8 : i32
%1 = select %0, %arg0, %c8_i32_0 : i32
return %1 : i32
}
}
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%c8_i32 = arith.constant 8 : i32
%c63_i32 = arith.constant 63 : i32
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%cst_0 = arith.constant dense<true> : tensor<64x32xi1>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%cst_2 = arith.constant dense<true> : tensor<32x64xi1>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%c32_i32 = arith.constant 32 : i32
%c1_i32 = arith.constant 1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
%18 = arith.addi %17, %16 : tensor<64xi32>
%19 = arith.muli %14, %c64_i32 : i32
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
%22 = arith.addi %21, %20 : tensor<64xi32>
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%26 = arith.muli %24, %25 : tensor<64x1xi32>
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%29 = arith.muli %27, %28 : tensor<1x32xi32>
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%32 = arith.addi %30, %31 : tensor<64x32xi32>
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%37 = arith.muli %35, %36 : tensor<32x1xi32>
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%40 = arith.muli %38, %39 : tensor<1x64xi32>
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%43 = arith.addi %41, %42 : tensor<32x64xi32>
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
%46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%47 = arith.index_cast %arg5 : i32 to index
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
%85 = arith.muli %arg7, %c32_i32 : i32
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
%50 = arith.muli %12, %c64_i32 : i32
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
%53 = arith.addi %52, %51 : tensor<64xi32>
%54 = arith.muli %14, %c64_i32 : i32
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%60 = arith.muli %59, %58 : tensor<64x1xi32>
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%65 = arith.muli %63, %64 : tensor<1x64xi32>
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%77 = arith.andi %75, %76 : tensor<64x64xi1>
tt.store %68, %49, %77, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%0 = arith.addi %arg0, %c63_i32 : i32
%1 = arith.divsi %0, %c64_i32 : i32
return %1 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%1 = select %0, %arg0, %c8_i32 : i32
return %1 : i32
}
}
.
============================== 6 passed in 1.21s ===============================

View File

@@ -2,12 +2,14 @@ import pytest
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
import torch
def test_if():
ref_ir = """module {
func @only_if(%arg0: i32, %arg1: i32, %arg2: i32) {
func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%cst = arith.constant -1.000000e+00 : f32
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) {
@@ -36,7 +38,7 @@ def test_if():
def test_if_else():
ref_ir = """module {
func @if_else(%arg0: i32, %arg1: i32, %arg2: i32) {
func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) {
%cst = arith.constant 0.000000e+00 : f32
@@ -65,7 +67,7 @@ def test_if_else():
def test_for():
ref_ir = """module {
func @for_loop(%arg0: i32) {
func @for_loop__i32__(%arg0: i32) {
%cst = arith.constant 1.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
@@ -95,7 +97,7 @@ def test_for():
def test_while():
ref_ir = """module {
func @generic_while(%arg0: i32) {
func @generic_while__i32__(%arg0: i32) {
%c-1_i32 = arith.constant -1 : i32
%0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 {
%c0_i32 = arith.constant 0 : i32
@@ -124,7 +126,7 @@ def test_while():
def test_nested():
ref_ir = """module {
func @nested_cf(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) {
func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) {
%cst = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
@@ -402,4 +404,14 @@ def test_matmul():
)
verify = mod.verify()
assert verify
assert ref_ir == mod.str()
# assert ref_ir == mod.str()
print(mod.str())
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())