[OPTIMIZER] Fixed up order of shared layouts (#881)

This commit is contained in:
Philippe Tillet
2022-11-21 06:25:02 +01:00
committed by GitHub
parent 4d64ffb5fe
commit 23f71daa27
6 changed files with 27 additions and 27 deletions

View File

@@ -74,11 +74,9 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
let builders = [
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
// Only support row major for now
// TODO(Keren): check why column major code crashes
SmallVector<unsigned> order = {1, 0};
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);

View File

@@ -464,21 +464,6 @@ struct SharedMemoryObject {
}
}
// XXX(Keren): a special allocator for 3d tensors. It's a workaround for
// now since we don't have a correct way to encoding 3d tensors in the
// pipeline pass.
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
auto stride = 1;
for (auto dim : llvm::reverse(shape)) {
strides.emplace_back(i32_val(stride));
offsets.emplace_back(i32_val(0));
stride *= dim;
}
strides = llvm::to_vector<4>(llvm::reverse(strides));
}
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
@@ -2251,8 +2236,18 @@ struct AllocTensorOpConversion
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
// workaround for 3D tensors
// TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() == 3)
newOrder = {1 + order[0], 1 + order[1], 0};
else
newOrder = SmallVector<unsigned>(order.begin(), order.end());
auto smemObj =
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
@@ -2302,6 +2297,10 @@ struct ExtractSliceOpConversion
strideVals.emplace_back(smemObj.strides[i]);
}
}
// llvm::outs() << "extract slice\n";
// llvm::outs() << strideVals[0] << " " << smemObj.strides[1] << "\n";
// llvm::outs() << strideVals[1] << " " << smemObj.strides[2] << "\n";
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto resTy = op.getType().dyn_cast<RankedTensorType>();
@@ -3262,8 +3261,8 @@ public:
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
cStride = smemStrides[1];
sStride = smemStrides[0];
cStride = smemStrides[order[0]];
sStride = smemStrides[order[1]];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
@@ -6202,6 +6201,7 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
srcType.getShape(),
getOrder(srcBlocked),
srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());

View File

@@ -201,7 +201,9 @@ LogicalResult LoopPipeliner::initialize() {
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()),
ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}

View File

@@ -186,9 +186,9 @@ def get_proper_err(a, b, golden):
[128, 256, 128, 4, 128, 256, 32, False, False],
[256, 128, 64, 4, 256, 128, 16, False, False],
[128, 64, 128, 4, 128, 64, 32, False, False],
# TODO[goostavz]: fix these cases
#[128, 64, 128, 4, 128, 64, 32, True, False],
#[128, 64, 128, 4, 128, 64, 32, False, True],
# trans
[128, 64, 128, 4, 128, 64, 32, True, False],
[128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
if (TRANS_A):

View File

@@ -882,6 +882,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
pm.enable_debug()
# Convert blocked layout to mma layout for dot ops so that pipeline
# can get shared memory swizzled correctly.
pm.add_coalesce_pass()
pm.add_triton_gpu_combine_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
@@ -889,7 +890,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_coalesce_pass()
pm.add_triton_gpu_combine_pass()
pm.add_licm_pass()
pm.add_triton_gpu_combine_pass()

View File

@@ -34,7 +34,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
// create element type
Type eltType = IntegerType::get(&ctx, params.typeWidth);
auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType);
auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);