[OPTIMIZER] Fixed up order of shared layouts (#881)
@@ -74,11 +74,9 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
   let builders = [
     AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                      "ArrayRef<int64_t>":$shape,
+                     "ArrayRef<unsigned>":$order,
                      "Type":$eltTy), [{
       auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
-      // Only support row major for now
-      // TODO(Keren): check why column major code crashes
-      SmallVector<unsigned> order = {1, 0};
 
       if(!mmaEnc)
         return $_get(context, 1, 1, 1, order);
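Note: with this change the shared-layout builder no longer hard-codes a row-major {1, 0} order; callers pass the operand's order explicitly. A minimal sketch of the new call shape, mirroring the updated unit test further below (the helper name buildSharedLayout and its argument list are illustrative, not part of the patch):

    // Hypothetical helper; assumes an MLIRContext, a DotOperandEncodingAttr,
    // a shape and an element type are already in scope (as in the unit test).
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    mlir::triton::gpu::SharedEncodingAttr
    buildSharedLayout(mlir::MLIRContext &ctx,
                      mlir::triton::gpu::DotOperandEncodingAttr dotOpEnc,
                      llvm::ArrayRef<int64_t> shape, mlir::Type eltTy) {
      // {1, 0}: dim 1 is the fastest-varying axis (row-major);
      // pass {0, 1} instead for a column-major operand.
      return mlir::triton::gpu::SharedEncodingAttr::get(&ctx, dotOpEnc, shape,
                                                        {1, 0}, eltTy);
    }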
@@ -464,21 +464,6 @@ struct SharedMemoryObject {
     }
   }
 
-  // XXX(Keren): a special allocator for 3d tensors. It's a workaround for
-  // now since we don't have a correct way to encoding 3d tensors in the
-  // pipeline pass.
-  SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
-                     ConversionPatternRewriter &rewriter)
-      : base(base) {
-    auto stride = 1;
-    for (auto dim : llvm::reverse(shape)) {
-      strides.emplace_back(i32_val(stride));
-      offsets.emplace_back(i32_val(0));
-      stride *= dim;
-    }
-    strides = llvm::to_vector<4>(llvm::reverse(strides));
-  }
-
   SmallVector<Value> getElems() const {
     SmallVector<Value> elems;
     elems.push_back(base);
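For reference, the deleted constructor derived row-major strides from the shape (last dim contiguous, dim 0 slowest). A standalone sketch of that computation in plain C++, using a hypothetical 3-D buffer shape of [4, 32, 64]:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int64_t> shape = {4, 32, 64};  // hypothetical 3-D buffer shape
      std::vector<int64_t> strides;
      int64_t stride = 1;
      // walk dims from innermost to outermost, accumulating the stride
      for (auto it = shape.rbegin(); it != shape.rend(); ++it) {
        strides.push_back(stride);               // innermost dim gets stride 1
        stride *= *it;
      }
      std::reverse(strides.begin(), strides.end());
      for (int64_t s : strides)
        std::cout << s << ' ';                   // prints: 2048 64 1
      std::cout << '\n';
    }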
@@ -2251,8 +2236,18 @@ struct AllocTensorOpConversion
         getTypeConverter()->convertType(resultTy.getElementType());
     auto elemPtrTy = ptr_ty(llvmElemTy, 3);
     smemBase = bitcast(smemBase, elemPtrTy);
+    auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
+    // workaround for 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    SmallVector<unsigned> newOrder;
+    if (resultTy.getShape().size() == 3)
+      newOrder = {1 + order[0], 1 + order[1], 0};
+    else
+      newOrder = SmallVector<unsigned>(order.begin(), order.end());
 
 
     auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
+        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
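The 3-D branch above shifts every dimension of the 2-D order up by one and appends 0, because the pipeline pass prepends the stage index as dimension 0 (see the LoopPipeliner hunk further below), which should be the slowest-varying axis. A small sketch with an assumed row-major 2-D order:

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<unsigned> order = {1, 0};  // assumed 2-D row-major order
      // lift to 3-D: old dims shift up by one, stage dim 0 becomes slowest
      std::vector<unsigned> newOrder = {1 + order[0], 1 + order[1], 0};
      for (unsigned d : newOrder)
        std::cout << d << ' ';               // prints: 2 1 0
      std::cout << '\n';
    }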
@@ -2302,6 +2297,10 @@ struct ExtractSliceOpConversion
         strideVals.emplace_back(smemObj.strides[i]);
       }
     }
+
+    // llvm::outs() << "extract slice\n";
+    // llvm::outs() << strideVals[0] << " " << smemObj.strides[1] << "\n";
+    // llvm::outs() << strideVals[1] << " " << smemObj.strides[2] << "\n";
     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(llvmElemTy, 3);
     auto resTy = op.getType().dyn_cast<RankedTensorType>();
@@ -3262,8 +3261,8 @@ public:
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
 
-    cStride = smemStrides[1];
-    sStride = smemStrides[0];
+    cStride = smemStrides[order[0]];
+    sStride = smemStrides[order[1]];
 
     // rule: k must be the fast-changing axis.
     needTrans = kOrder != order[0];
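Before this hunk, the contiguous-axis and strided-axis strides were hard-wired to smemStrides[1] and smemStrides[0], i.e. a row-major assumption; indexing through order lets column-major operands pick the right strides. An illustrative check with assumed values (not taken from the patch):

    #include <iostream>

    int main() {
      unsigned order[2] = {0, 1};            // assumed column-major operand
      long smemStrides[2] = {1, 64};         // assumed: dim 0 contiguous, rows 64 apart
      long cStride = smemStrides[order[0]];  // 1   (old code read smemStrides[1] == 64)
      long sStride = smemStrides[order[1]];  // 64  (old code read smemStrides[0] == 1)
      unsigned kOrder = 1;                   // assumed: k is dimension 1
      bool needTrans = kOrder != order[0];   // true: k is not the fast-changing axis
      std::cout << cStride << ' ' << sStride << ' ' << needTrans << '\n';
    }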
@@ -6202,6 +6201,7 @@ private:
         dstType.getShape(), dstType.getElementType(),
         triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
                                              srcType.getShape(),
+                                             getOrder(srcBlocked),
                                              srcType.getElementType()));
     auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
         cvtOp.getLoc(), tmpType, cvtOp.getOperand());
@@ -201,7 +201,9 @@ LogicalResult LoopPipeliner::initialize() {
                                              ty.getShape().end());
       bufferShape.insert(bufferShape.begin(), numStages);
       auto sharedEnc = ttg::SharedEncodingAttr::get(
-          ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
+          ty.getContext(), dotOpEnc, ty.getShape(),
+          triton::gpu::getOrder(ty.getEncoding()),
+          ty.getElementType());
       loadsBufferType[loadOp] = RankedTensorType::get(
           bufferShape, ty.getElementType(), sharedEnc);
     }
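Read together with the builder change above: the pipelined shared buffer now takes its order from the load's existing blocked layout via triton::gpu::getOrder instead of defaulting to row-major, and the stage count is prepended as dimension 0. A condensed sketch of the same logic as a free function (the name makeBufferType and its parameter list are illustrative):

    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    namespace ttg = mlir::triton::gpu;

    mlir::RankedTensorType makeBufferType(mlir::RankedTensorType ty,
                                          ttg::DotOperandEncodingAttr dotOpEnc,
                                          int numStages) {
      llvm::SmallVector<int64_t, 4> bufferShape(ty.getShape().begin(),
                                                ty.getShape().end());
      bufferShape.insert(bufferShape.begin(), numStages);  // stage dim in front
      auto sharedEnc = ttg::SharedEncodingAttr::get(
          ty.getContext(), dotOpEnc, ty.getShape(),
          ttg::getOrder(ty.getEncoding()),                  // order of the load
          ty.getElementType());
      return mlir::RankedTensorType::get(bufferShape, ty.getElementType(),
                                         sharedEnc);
    }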
@@ -186,9 +186,9 @@ def get_proper_err(a, b, golden):
     [128, 256, 128, 4, 128, 256, 32, False, False],
     [256, 128, 64, 4, 256, 128, 16, False, False],
     [128, 64, 128, 4, 128, 64, 32, False, False],
-    # TODO[goostavz]: fix these cases
-    #[128, 64, 128, 4, 128, 64, 32, True, False],
-    #[128, 64, 128, 4, 128, 64, 32, False, True],
+    # trans
+    [128, 64, 128, 4, 128, 64, 32, True, False],
+    [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
     if (TRANS_A):
@@ -882,6 +882,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
     pm.enable_debug()
     # Convert blocked layout to mma layout for dot ops so that pipeline
     # can get shared memory swizzled correctly.
+    pm.add_coalesce_pass()
     pm.add_triton_gpu_combine_pass()
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
@@ -889,7 +890,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
     pm.add_tritongpu_prefetch_pass()
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
-    pm.add_coalesce_pass()
     pm.add_triton_gpu_combine_pass()
     pm.add_licm_pass()
     pm.add_triton_gpu_combine_pass()
@@ -34,7 +34,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
 
   // create element type
   Type eltType = IntegerType::get(&ctx, params.typeWidth);
-  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType);
+  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
 
   ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
   ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);