[TritonGPU] Improved documentation and semantics of layout encodings (#30)
This commit is contained in:
@@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
|
||||
Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
|
||||
tl.store(Zs, tl.load(Xs))
|
||||
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
|
||||
print(ret)
|
||||
|
@@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) {
|
||||
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
|
||||
})
|
||||
// Block instruction
|
||||
.def("create_reshape",
|
||||
.def("create_view",
|
||||
[](mlir::OpBuilder &self, mlir::Value &arg,
|
||||
std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType()
|
||||
.dyn_cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
return self.create<mlir::triton::ReshapeOp>(
|
||||
return self.create<mlir::triton::ViewOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg);
|
||||
})
|
||||
.def("create_cat",
|
||||
|
@@ -565,7 +565,7 @@ class tensor:
|
||||
elif sl == slice(None, None, None):
|
||||
dst_shape.append(src_shape[curr].value)
|
||||
curr += 1
|
||||
ret = semantic.reshape(self, dst_shape, _builder)
|
||||
ret = semantic.view(self, dst_shape, _builder)
|
||||
return ret
|
||||
|
||||
@builtin
|
||||
|
@@ -451,16 +451,16 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def view(input: tl.tensor,
         dst_shape: List[int],
         builder: ir.builder) -> tl.tensor:
    """Reinterpret `input` as a tensor of shape `dst_shape`.

    Emits a `create_view` op through `builder`. The target shape must
    contain exactly as many elements as the input block.

    :param input: source tensor whose handle is passed to the view op.
    :param dst_shape: target shape; the product of its extents must equal
        `input.type.numel`.
    :param builder: IR builder used to emit the op.
    :return: a `tl.tensor` wrapping the new handle, typed as a block of
        `input.type.scalar` with shape `dst_shape`.
    :raises ValueError: if the element counts of input and target differ.
    """
    # Local import keeps this change self-contained within the function.
    import math

    # math.prod replaces the manual `numel = 1; for s in dst_shape: ...`
    # accumulation loop.
    numel = math.prod(dst_shape)
    if input.type.numel != numel:
        # Fix: the op was renamed reshape -> view in this change, but the
        # error text still said "reshape"; keep the message consistent
        # with the function's actual name.
        raise ValueError("cannot view block of different shape")
    ret_ty = tl.block_type(input.type.scalar, dst_shape)
    return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
||||
|
||||
|
||||
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
|
Reference in New Issue
Block a user