Fix OpBuilder
@@ -123,8 +123,10 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
 
   let builders = [
       // for args with default values
-      OpBuilder<(ins "Value":$ptr)>,
-      OpBuilder<(ins "Value":$ptr, "Value":$mask)>
+      OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
+                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
+      OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
+                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
   ];
 }
 
@@ -199,7 +201,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
     $d = matrix_multiply($a, $b) + $c
   }];
 
-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
 
   let results = (outs TT_FpIntTensor:$d);
 }
@@ -33,7 +33,8 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, :
 }
 
 //-- LoadOp --
-void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) {
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
+                   ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
@@ -57,6 +58,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
   state.addOperands(ptr);
   state.addOperands(mask);
   state.addOperands(other);
+  state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
+  state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
+  state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
   state.addTypes({resultType});
 }
 
@@ -1155,9 +1155,12 @@ void init_triton_ir(py::module &&m) {
            return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
          })
       // // Input/Output
-      .def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
+      .def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs,
+                             mlir::triton::CacheModifier cacheModifer,
+                             mlir::triton::EvictionPolicy evictionPolicy,
+                             bool isVolatile) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        return self.create<mlir::triton::LoadOp>(loc, ptrs);
+        return self.create<mlir::triton::LoadOp>(loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
       })
       .def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
         auto loc = self.getUnknownLoc();
@@ -1200,8 +1203,7 @@ void init_triton_ir(py::module &&m) {
       })
       .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        // TODO: should be scalar type here
-        auto argType = arg.getType();
+        auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
         return self.create<mlir::triton::BroadcastOp>(
           loc, mlir::RankedTensorType::get(shape, argType), arg
         );
@@ -1246,9 +1248,9 @@ void init_triton_ir(py::module &&m) {
           loc, self.getI32Type(), self.getI32IntegerAttr(axis)
         );
       })
-      .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
+      .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c, bool allowTF32) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
+        return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c, allowTF32);
       })
       // .def("create_exp", &ir::builder::create_exp, ret::reference)
       // .def("create_cos", &ir::builder::create_cos, ret::reference)
@@ -1257,7 +1259,11 @@ void init_triton_ir(py::module &&m) {
       // .def("create_trans", &ir::builder::create_trans, ret::reference)
       // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
       // .def("create_reduce", &ir::builder::create_reduce, ret::reference)
-      // .def("create_select", &ir::builder::create_select, ret::reference)
+      .def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
+                               mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
+        auto loc = self.getUnknownLoc();
+        return self.create<mlir::SelectOp>(loc, condition, trueValue, falseValue);
+      })
       // // Intrinsics
       // // These have no place in the IR, and hopefully they can be removed at some point
       // .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
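Taken together, these binding changes alter the call shapes the Python frontend has to emit. A minimal sketch of the new patterns, assuming `builder` is the pybind11-wrapped `mlir::OpBuilder` defined above and that the `CacheModifier`/`EvictionPolicy` enums are also exposed to Python; the function and argument names here are illustrative, not taken from this commit:

```python
# Sketch only: `builder`, `cache_modifier`, and `eviction_policy` stand in for
# objects exposed by the pybind11 module; everything outside this diff is assumed.
def lower_ops(builder, ptrs, cond, a, b, c, x, y, cache_modifier, eviction_policy):
    # create_load now takes cache modifier, eviction policy, and a volatility flag
    loaded = builder.create_load(ptrs, cache_modifier, eviction_policy, False)
    # create_dot now takes an explicit allowTF32 flag
    d = builder.create_dot(a, b, c, True)
    # create_select is newly exposed: select(condition, true_value, false_value)
    chosen = builder.create_select(cond, x, y)
    return loaded, d, chosen
```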
@@ -1177,7 +1177,16 @@ class JITFunction:
     # Compile to ttir, for the propose of testing MLIR rewriting
     def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # TODO: share code with _compile & __call__
+        # handle arguments passed by name
+        kwargs = {self.arg_names.index(name): value for name, value in kwargs.items()}
+        wargs = list(wargs)
+        for i, pos in enumerate(sorted(kwargs)):
+            wargs.insert(pos + i, kwargs[pos])
+        if len(wargs) != len(self.arg_names):
+            raise TypeError(f"Function takes {len(self.arg_names)} positional arguments but {len(wargs)} were given")
+        # handle annotations
+        for pos, _type in self.annotations.items():
+            wargs[pos] = _type(wargs[pos])
         # preparing args
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         # attributes
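The block added above converts keyword arguments to parameter positions and splices them back into the positional list before annotations are applied. A small, self-contained sketch of that reordering step, with an invented parameter list and call purely for illustration:

```python
# Standalone illustration of the keyword-argument handling added above.
# `arg_names` and the example call are made up for demonstration.
arg_names = ['x_ptr', 'y_ptr', 'output_ptr', 'n_elements', 'BLOCK_SIZE']

def reorder(args, kwargs):
    # map keyword names to the positional index of the corresponding parameter
    by_pos = {arg_names.index(name): value for name, value in kwargs.items()}
    args = list(args)
    # splice each keyword value back in at its parameter position, ascending
    for i, pos in enumerate(sorted(by_pos)):
        args.insert(pos + i, by_pos[pos])
    if len(args) != len(arg_names):
        raise TypeError(f"Function takes {len(arg_names)} positional arguments "
                        f"but {len(args)} were given")
    return args

print(reorder(('X', 'Y', 'OUT', 1024), {'BLOCK_SIZE': 256}))
# ['X', 'Y', 'OUT', 1024, 256]
```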
@@ -1191,7 +1200,7 @@ class JITFunction:
                 attributes[i] = min(Kernel.pow2_divisor(addr),
                                     Kernel.pow2_divisor(range_size))
         # transforms ints whose value is one into constants for just-in-time compilation
-        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
+        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
         constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
         constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
         arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
@@ -1,3 +1,4 @@
+from tarfile import BLOCKSIZE
 import torch
 import triton
 import triton.language as tl
@@ -9,8 +10,8 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
-    # BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
-    # # NOTE: `constexpr` so it can be used as a shape value
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
+    # NOTE: `constexpr` so it can be used as a shape value
 ):
     # There are multiple 'program's processing different data. We identify which program
     # we are here
@@ -19,8 +20,8 @@ def add_kernel(
     # for instance, if you had a vector of length 256 and block_size of 64, the programs
     # would each access the elements [0:64, 64:128, 128:192, 192:256].
     # Note that offsets is a list of pointers
-    block_start = pid * 256
-    offsets = block_start + tl.arange(0, 256)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a
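The comments in this hunk describe per-program block indexing; a quick plain-Python check of the same arithmetic, with NumPy standing in for `tl.arange` and example values for the block size, vector length, and program id:

```python
# Plain-Python check of the block indexing described above.
# BLOCK_SIZE, n_elements, and pid are example values, not from the kernel launch.
import numpy as np

BLOCK_SIZE = 64
n_elements = 200
pid = 3                                    # last program for a length-200 vector

block_start = pid * BLOCK_SIZE             # 192
offsets = block_start + np.arange(0, BLOCK_SIZE)   # [192, 193, ..., 255]
mask = offsets < n_elements                # True for 192..199, False for 200..255

print(offsets[:4], offsets[-1])            # [192 193 194 195] 255
print(mask.sum())                          # 8 in-bounds elements in the last block
```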
@@ -37,7 +38,6 @@ y = torch.rand(size, device='cuda')
 z = torch.empty_like(x)
 # add_kernel[(1,)](x, y, z, size, 256)
 # print(add_kernel[(1,)].kernel.compile_to_ttir())
-mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, grid=(1,))
-mod.get_context()
+# print(add_kernel.annotations)
+mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
 mod.dump()
-# print(mod)