[TritonGPU] Improved documentation and semantics of layout encodings (#30)

This commit is contained in:
Philippe Tillet
2022-07-31 13:59:44 -07:00
committed by GitHub
parent e02c82c765
commit d1593e6ca8
17 changed files with 399 additions and 566 deletions

View File

@@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) {
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
})
// Block instruction
.def("create_reshape",
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType()
.dyn_cast<mlir::RankedTensorType>()
.getElementType();
return self.create<mlir::triton::ReshapeOp>(
return self.create<mlir::triton::ViewOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg);
})
.def("create_cat",

View File

@@ -565,7 +565,7 @@ class tensor:
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr].value)
curr += 1
ret = semantic.reshape(self, dst_shape, _builder)
ret = semantic.view(self, dst_shape, _builder)
return ret
@builtin

View File

@@ -451,16 +451,16 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
# ===----------------------------------------------------------------------===//
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
numel = 1
for s in dst_shape:
numel *= s
if input.type.numel != numel:
raise ValueError("cannot reshape block of different shape")
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: