[BACKEND] Added support for 1D conversion blocked -> slice (#831)
This commit is contained in:
@@ -11,7 +11,9 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||||
|
using ::mlir::triton::gpu::getOrder;
|
||||||
using ::mlir::triton::gpu::getShapePerCTA;
|
using ::mlir::triton::gpu::getShapePerCTA;
|
||||||
|
using ::mlir::triton::gpu::getSizePerThread;
|
||||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||||
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
|||||||
"Unexpect layout in getScratchConfigForCvtLayout()");
|
"Unexpect layout in getScratchConfigForCvtLayout()");
|
||||||
unsigned rank = dstTy.getRank();
|
unsigned rank = dstTy.getRank();
|
||||||
SmallVector<unsigned> paddedRepShape(rank);
|
SmallVector<unsigned> paddedRepShape(rank);
|
||||||
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
|
|
||||||
srcLayout = srcSliceLayout.getParent();
|
|
||||||
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
|
|
||||||
dstLayout = dstSliceLayout.getParent();
|
|
||||||
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
||||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||||
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
||||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||||
assert((srcBlockedLayout || srcMmaLayout) &&
|
|
||||||
"Unexpected srcLayout in getScratchConfigForCvtLayout");
|
|
||||||
assert((dstBlockedLayout || dstMmaLayout) &&
|
|
||||||
"Unexpected dstLayout in getScratchConfigForCvtLayout");
|
|
||||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||||
"Unexpected mma -> mma layout conversion");
|
"Unexpected mma -> mma layout conversion");
|
||||||
auto inOrd =
|
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
|
||||||
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
|
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
|
||||||
auto outOrd =
|
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
|
||||||
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
|
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
|
||||||
unsigned srcContigPerThread =
|
|
||||||
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
|
|
||||||
unsigned dstContigPerThread =
|
|
||||||
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
|
|
||||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||||
// that we cannot do vectorization.
|
// that we cannot do vectorization.
|
||||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||||
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
|||||||
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||||
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||||
}
|
}
|
||||||
|
if (rank == 1)
|
||||||
|
return paddedRepShape;
|
||||||
unsigned paddedDim = 1;
|
unsigned paddedDim = 1;
|
||||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||||
|
@@ -1197,7 +1197,7 @@ struct BroadcastOpConversion
|
|||||||
broadcastDims.push_back(d);
|
broadcastDims.push_back(d);
|
||||||
srcLogicalShape[d] = 1;
|
srcLogicalShape[d] = 1;
|
||||||
srcLogicalShape[d + rank] =
|
srcLogicalShape[d + rank] =
|
||||||
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
|
std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
|
||||||
} else {
|
} else {
|
||||||
srcLogicalShape[d] = numCtas;
|
srcLogicalShape[d] = numCtas;
|
||||||
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
|
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
|
||||||
@@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
|||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
if (dstLayout.isa<BlockedEncodingAttr>() ||
|
if (dstLayout.isa<BlockedEncodingAttr>() ||
|
||||||
|
dstLayout.isa<SliceEncodingAttr>() ||
|
||||||
dstLayout.isa<MmaEncodingAttr>()) {
|
dstLayout.isa<MmaEncodingAttr>()) {
|
||||||
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
|
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
|
||||||
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
|
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
|
||||||
|
@@ -68,18 +68,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
|||||||
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
|
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
|
||||||
blockedLayout.getSizePerThread().end());
|
blockedLayout.getSizePerThread().end());
|
||||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||||
unsigned dim = sliceLayout.getDim();
|
return getSizePerThread(sliceLayout.getParent());
|
||||||
auto parent = sliceLayout.getParent();
|
|
||||||
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
|
||||||
SmallVector<unsigned> sizePerThread(
|
|
||||||
blockedParent.getSizePerThread().begin(),
|
|
||||||
blockedParent.getSizePerThread().end());
|
|
||||||
sizePerThread.erase(sizePerThread.begin() + dim);
|
|
||||||
return sizePerThread;
|
|
||||||
} else {
|
|
||||||
assert(0 && "SliceEncodingAttr with parent other than "
|
|
||||||
"BlockedEncodingAttr not implemented");
|
|
||||||
}
|
|
||||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
assert(mmaLayout.getVersion() == 2 &&
|
assert(mmaLayout.getVersion() == 2 &&
|
||||||
"mmaLayout version = 1 is not implemented yet");
|
"mmaLayout version = 1 is not implemented yet");
|
||||||
@@ -144,6 +133,19 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
|
|||||||
blockedLayout.getOrder().end());
|
blockedLayout.getOrder().end());
|
||||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
return SmallVector<unsigned>{1, 0};
|
return SmallVector<unsigned>{1, 0};
|
||||||
|
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||||
|
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
|
||||||
|
unsigned dim = sliceLayout.getDim();
|
||||||
|
SmallVector<unsigned> order;
|
||||||
|
for (unsigned d : parentOrder) {
|
||||||
|
if (d == dim)
|
||||||
|
continue;
|
||||||
|
else if (d > dim)
|
||||||
|
order.push_back(d - 1);
|
||||||
|
else
|
||||||
|
order.push_back(d);
|
||||||
|
}
|
||||||
|
return order;
|
||||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||||
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
|
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
|
||||||
sharedLayout.getOrder().end());
|
sharedLayout.getOrder().end());
|
||||||
|
@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
|
|||||||
|
|
||||||
|
|
||||||
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`
|
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`
|
||||||
# @pytest.mark.parametrize("expr, dtype_str", [
|
@pytest.mark.parametrize("expr, dtype_str", [
|
||||||
# (f'x[{s}]', d)
|
(f'x[{s}]', d)
|
||||||
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
for s in ['None, :', ':, None',
|
||||||
# for d in ['int32', 'uint32', 'uint16']
|
# TODO: 3D
|
||||||
# ])
|
# 'None, :, :',
|
||||||
# def test_index1d(expr, dtype_str, device='cuda'):
|
# ':, :, None'
|
||||||
# rank_x = expr.count(':')
|
]
|
||||||
# rank_y = expr.count(',') + 1
|
for d in ['int32', 'uint32', 'uint16']
|
||||||
# shape_x = [32 for _ in range(rank_x)]
|
])
|
||||||
# shape_z = [32 for _ in range(rank_y)]
|
def test_index1d(expr, dtype_str, device='cuda'):
|
||||||
# shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
|
rank_x = expr.count(':')
|
||||||
# shape_z_dim_mismatch = [64 for _ in range(rank_y)]
|
rank_y = expr.count(',') + 1
|
||||||
|
shape_x = [32 for _ in range(rank_x)]
|
||||||
|
shape_z = [32 for _ in range(rank_y)]
|
||||||
|
shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
|
||||||
|
shape_z_dim_mismatch = [64 for _ in range(rank_y)]
|
||||||
|
|
||||||
# # Triton kernel
|
# Triton kernel
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def kernel(Z, X, SIZE: tl.constexpr):
|
def kernel(Z, X, SIZE: tl.constexpr):
|
||||||
# m = tl.arange(0, SIZE)
|
m = tl.arange(0, SIZE)
|
||||||
# n = tl.arange(0, SIZE)
|
n = tl.arange(0, SIZE)
|
||||||
# x = tl.load(X_PTR_EXPR)
|
x = tl.load(X_PTR_EXPR)
|
||||||
# z = GENERATE_TEST_HERE
|
z = GENERATE_TEST_HERE
|
||||||
# tl.store(Z_PTR_EXPR, z)
|
tl.store(Z_PTR_EXPR, z)
|
||||||
|
|
||||||
# def generate_kernel(shape_x, shape_z):
|
def generate_kernel(shape_x, shape_z):
|
||||||
# to_replace = {
|
to_replace = {
|
||||||
# 'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
||||||
# 'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
||||||
# 'GENERATE_TEST_HERE': expr,
|
'GENERATE_TEST_HERE': expr,
|
||||||
# }
|
}
|
||||||
# return patch_kernel(kernel, to_replace)
|
return patch_kernel(kernel, to_replace)
|
||||||
|
|
||||||
# kernel_match = generate_kernel(shape_x, shape_z)
|
kernel_match = generate_kernel(shape_x, shape_z)
|
||||||
# kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
|
kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
|
||||||
# kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
|
kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
|
||||||
|
|
||||||
# # torch result
|
# torch result
|
||||||
# x = numpy_random(shape_x, dtype_str=dtype_str)
|
x = numpy_random(shape_x, dtype_str=dtype_str)
|
||||||
# y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
|
y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
|
||||||
# z_ref = eval(expr) + y
|
z_ref = eval(expr) + y
|
||||||
# # triton result
|
# triton result
|
||||||
# z_tri = to_triton(np.empty_like(z_ref), device=device)
|
z_tri = to_triton(np.empty_like(z_ref), device=device)
|
||||||
# x_tri = to_triton(x)
|
x_tri = to_triton(x)
|
||||||
# kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||||
# # compare
|
# compare
|
||||||
# assert (z_ref == to_numpy(z_tri)).all()
|
assert (z_ref == to_numpy(z_tri)).all()
|
||||||
|
|
||||||
# def catch_compilation_error(kernel):
|
def catch_compilation_error(kernel):
|
||||||
# try:
|
try:
|
||||||
# kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||||
# except triton.CompilationError as e:
|
except triton.CompilationError as e:
|
||||||
# np.testing.assert_(True)
|
np.testing.assert_(True)
|
||||||
# except BaseException:
|
except BaseException:
|
||||||
# np.testing.assert_(False)
|
np.testing.assert_(False)
|
||||||
|
|
||||||
# catch_compilation_error(kernel_dim_mismatch)
|
catch_compilation_error(kernel_dim_mismatch)
|
||||||
# catch_compilation_error(kernel_rank_mismatch)
|
catch_compilation_error(kernel_rank_mismatch)
|
||||||
|
|
||||||
|
|
||||||
# # ---------------
|
# # ---------------
|
||||||
|
@@ -715,6 +715,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||||
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||||
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
// CHECK-LABEL: convert_blocked1d_to_slice0
|
||||||
|
func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
||||||
|
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||||
|
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||||
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||||
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
// CHECK-LABEL: convert_blocked1d_to_slice1
|
||||||
|
func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||||
|
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||||
|
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||||
|
Reference in New Issue
Block a user