[BACKEND] Added support for 1D conversion blocked -> slice (#831)

Philippe Tillet, 2022-11-01 13:19:58 -07:00 (committed by GitHub)
commit 12d60cb4a3, parent c9d84237e8
5 changed files with 103 additions and 78 deletions

--- File 1 of 5 ---

@@ -11,7 +11,9 @@
 #include <numeric>

 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
          "Unexpect layout in getScratchConfigForCvtLayout()");
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> paddedRepShape(rank);
-  if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
-    srcLayout = srcSliceLayout.getParent();
-  if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
-    dstLayout = dstSliceLayout.getParent();
   auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
   auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
   auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
   auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
-  assert((srcBlockedLayout || srcMmaLayout) &&
-         "Unexpected srcLayout in getScratchConfigForCvtLayout");
-  assert((dstBlockedLayout || dstMmaLayout) &&
-         "Unexpected dstLayout in getScratchConfigForCvtLayout");
   assert(!(srcMmaLayout && dstMmaLayout) &&
          "Unexpected mma -> mma layout conversion");
-  auto inOrd =
-      srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
-  auto outOrd =
-      dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
-  unsigned srcContigPerThread =
-      srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
-  unsigned dstContigPerThread =
-      dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
+  auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
+  auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
+  unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
+  unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
         std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
                  std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
   }
+  if (rank == 1)
+    return paddedRepShape;
   unsigned paddedDim = 1;
   if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
     paddedDim = dstBlockedLayout.getOrder()[0];
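For orientation, here is a rough Python model of the shape computation the hunks above change; it is a sketch under assumed names, not the Triton implementation. The scratch tensor used to round-trip a layout conversion through shared memory covers the larger of the source and destination per-CTA tiles in every dimension, and the new early return says that rank-1 conversions skip the anti-bank-conflict padding entirely.

# Hypothetical model of the repShape computation in
# getScratchConfigForCvtLayout (names assumed for illustration).
def scratch_rep_shape(src_shape, src_shape_per_cta,
                      dst_shape, dst_shape_per_cta):
    rank = len(dst_shape)
    rep_shape = [
        max(min(src_shape[d], src_shape_per_cta[d]),
            min(dst_shape[d], dst_shape_per_cta[d]))
        for d in range(rank)
    ]
    if rank == 1:
        # New in this commit: 1D conversions are returned unpadded.
        return rep_shape
    # rank >= 2: the fastest-varying dim would be padded here to avoid
    # shared-memory bank conflicts (padding logic elided in this sketch).
    return rep_shape

# e.g. a 32-element blocked -> slice conversion:
assert scratch_rep_shape([32], [32], [32], [32]) == [32]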

--- File 2 of 5 ---

@@ -1197,7 +1197,7 @@ struct BroadcastOpConversion
       broadcastDims.push_back(d);
       srcLogicalShape[d] = 1;
       srcLogicalShape[d + rank] =
-          std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
+          std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
     } else {
       srcLogicalShape[d] = numCtas;
       srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
     }
     barrier();
     if (dstLayout.isa<BlockedEncodingAttr>() ||
+        dstLayout.isa<SliceEncodingAttr>() ||
         dstLayout.isa<MmaEncodingAttr>()) {
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
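The one-line addition above is what unlocks the new conversion: lowerDistributedToDistributed now accepts a slice encoding as the destination layout when reading data back out of shared memory. A hedged Triton-level example of code that would exercise this path (the kernel and launch parameters are illustrative assumptions, modeled on the test_index1d test re-enabled below):

# Hypothetical kernel: x[None, :] forces a blocked -> slice
# convert_layout of the 1D tensor loaded from X.
import torch
import triton
import triton.language as tl

@triton.jit
def expand_kernel(Z, X, SIZE: tl.constexpr):
    m = tl.arange(0, SIZE)                      # 1D range, #blocked layout
    x = tl.load(X + m)                          # 1D load
    z = x[None, :] + tl.zeros([SIZE, SIZE], dtype=tl.int32)
    tl.store(Z + SIZE * m[:, None] + m[None, :], z)

x = torch.arange(32, dtype=torch.int32, device='cuda')
z = torch.empty((32, 32), dtype=torch.int32, device='cuda')
expand_kernel[(1,)](z, x, SIZE=32, num_warps=1)
assert (z[5] == x).all()                        # every row is a copy of x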

--- File 3 of 5 ---

@@ -68,18 +68,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
-    unsigned dim = sliceLayout.getDim();
-    auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      SmallVector<unsigned> sizePerThread(
-          blockedParent.getSizePerThread().begin(),
-          blockedParent.getSizePerThread().end());
-      sizePerThread.erase(sizePerThread.begin() + dim);
-      return sizePerThread;
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
-    }
+    return getSizePerThread(sliceLayout.getParent());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     assert(mmaLayout.getVersion() == 2 &&
            "mmaLayout version = 1 is not implemented yet");
@@ -144,6 +133,19 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
                                  blockedLayout.getOrder().end());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     return SmallVector<unsigned>{1, 0};
+  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
+    unsigned dim = sliceLayout.getDim();
+    SmallVector<unsigned> order;
+    for (unsigned d : parentOrder) {
+      if (d == dim)
+        continue;
+      else if (d > dim)
+        order.push_back(d - 1);
+      else
+        order.push_back(d);
+    }
+    return order;
   } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
     return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
                                  sharedLayout.getOrder().end());
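The new getOrder branch derives a slice layout's dimension order by dropping the sliced dimension from the parent's order and shifting higher dimension indices down by one. The same logic in a few lines of Python, with a worked example using the 2D parent layout from the lit tests below (function name assumed):

# Assumed model of the slice branch of getOrder above.
def slice_order(parent_order, dim):
    order = []
    for d in parent_order:
        if d == dim:
            continue                           # the sliced dim disappears
        order.append(d - 1 if d > dim else d)  # higher dims shift down
    return order

# Parent order [1, 0] (row-major 2D, as in #blocked1 below):
assert slice_order([1, 0], dim=0) == [0]  # slicing away dim 0 -> 1D order
assert slice_order([1, 0], dim=1) == [0]  # slicing away dim 1 -> 1D order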

--- File 4 of 5 ---

@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
 # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
-# @pytest.mark.parametrize("expr, dtype_str", [
-#     (f'x[{s}]', d)
-#     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
-#     for d in ['int32', 'uint32', 'uint16']
-# ])
-# def test_index1d(expr, dtype_str, device='cuda'):
-#     rank_x = expr.count(':')
-#     rank_y = expr.count(',') + 1
-#     shape_x = [32 for _ in range(rank_x)]
-#     shape_z = [32 for _ in range(rank_y)]
-#     shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
-#     shape_z_dim_mismatch = [64 for _ in range(rank_y)]
-#     # Triton kernel
-#     @triton.jit
-#     def kernel(Z, X, SIZE: tl.constexpr):
-#         m = tl.arange(0, SIZE)
-#         n = tl.arange(0, SIZE)
-#         x = tl.load(X_PTR_EXPR)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z_PTR_EXPR, z)
-#     def generate_kernel(shape_x, shape_z):
-#         to_replace = {
-#             'X_PTR_EXPR': make_ptr_str('X', shape_x),
-#             'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
-#             'GENERATE_TEST_HERE': expr,
-#         }
-#         return patch_kernel(kernel, to_replace)
-#     kernel_match = generate_kernel(shape_x, shape_z)
-#     kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
-#     kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
-#     # torch result
-#     x = numpy_random(shape_x, dtype_str=dtype_str)
-#     y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
-#     z_ref = eval(expr) + y
-#     # triton result
-#     z_tri = to_triton(np.empty_like(z_ref), device=device)
-#     x_tri = to_triton(x)
-#     kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#     # compare
-#     assert (z_ref == to_numpy(z_tri)).all()
-#     def catch_compilation_error(kernel):
-#         try:
-#             kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#         except triton.CompilationError as e:
-#             np.testing.assert_(True)
-#         except BaseException:
-#             np.testing.assert_(False)
-#     catch_compilation_error(kernel_dim_mismatch)
-#     catch_compilation_error(kernel_rank_mismatch)
+@pytest.mark.parametrize("expr, dtype_str", [
+    (f'x[{s}]', d)
+    for s in ['None, :', ':, None',
+              # TODO: 3D
+              # 'None, :, :',
+              # ':, :, None'
+              ]
+    for d in ['int32', 'uint32', 'uint16']
+])
+def test_index1d(expr, dtype_str, device='cuda'):
+    rank_x = expr.count(':')
+    rank_y = expr.count(',') + 1
+    shape_x = [32 for _ in range(rank_x)]
+    shape_z = [32 for _ in range(rank_y)]
+    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
+    shape_z_dim_mismatch = [64 for _ in range(rank_y)]
+
+    # Triton kernel
+    @triton.jit
+    def kernel(Z, X, SIZE: tl.constexpr):
+        m = tl.arange(0, SIZE)
+        n = tl.arange(0, SIZE)
+        x = tl.load(X_PTR_EXPR)
+        z = GENERATE_TEST_HERE
+        tl.store(Z_PTR_EXPR, z)
+
+    def generate_kernel(shape_x, shape_z):
+        to_replace = {
+            'X_PTR_EXPR': make_ptr_str('X', shape_x),
+            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
+            'GENERATE_TEST_HERE': expr,
+        }
+        return patch_kernel(kernel, to_replace)
+
+    kernel_match = generate_kernel(shape_x, shape_z)
+    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
+    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
+
+    # torch result
+    x = numpy_random(shape_x, dtype_str=dtype_str)
+    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
+    z_ref = eval(expr) + y
+
+    # triton result
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x)
+    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+
+    # compare
+    assert (z_ref == to_numpy(z_tri)).all()
+
+    def catch_compilation_error(kernel):
+        try:
+            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+        except triton.CompilationError as e:
+            np.testing.assert_(True)
+        except BaseException:
+            np.testing.assert_(False)
+
+    catch_compilation_error(kernel_dim_mismatch)
+    catch_compilation_error(kernel_rank_mismatch)

 # # ---------------
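For a concrete sense of what the re-enabled test_index1d checks: the GENERATE_TEST_HERE expression is evaluated both through Triton and through numpy, and the two results must agree elementwise. A standalone numpy version of the 'x[None, :]' case (shapes taken from the test; everything else is plain numpy):

import numpy as np

x = np.arange(32, dtype=np.int32)                         # stands in for numpy_random
z_ref = x[None, :] + np.zeros((32, 32), dtype=np.int32)   # broadcast 1D -> 2D
assert z_ref.shape == (32, 32)
assert (z_ref[7] == x).all()                              # every row equals x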

--- File 5 of 5 ---

@@ -715,6 +715,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }

+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked1d_to_slice0
+  func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
+    // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
+    %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked1d_to_slice1
+  func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
+    // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
+    %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    return
+  }
+}
+
 // -----

 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>