[BACKEND] MMA->DotOperand conversion for chain dot of float32 tensors (#962)
Co-authored-by: Philippe Tillet <phil@openai.com>
@@ -2682,63 +2682,24 @@ public:
        dstLayout.isa<SliceEncodingAttr>())) {
      return lowerDistributedToDistributed(op, adaptor, rewriter);
    }
    // dot_op<opIdx=0, parent=#mma> = #mma
    // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
    if (srcLayout.isa<MmaEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
      auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
      auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
      if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
          dstDotLayout.getOpIdx() == 0 &&
          dstDotLayout.getParent() == srcMmaLayout) {
        // get source values
        Location loc = op->getLoc();
        auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
        unsigned elems = getElemsPerThread(srcTy);
        Type elemTy =
            this->getTypeConverter()->convertType(srcTy.getElementType());
        // for the destination type, we need to pack values together
        // so they can be consumed by tensor core operations
        unsigned vecSize =
            std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
        Type vecTy = vec_ty(elemTy, vecSize);
        SmallVector<Type> types(elems / vecSize, vecTy);
        SmallVector<Value> vecVals;
        for (unsigned i = 0; i < elems; i += vecSize) {
          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
          for (unsigned j = 0; j < vecSize; j++)
            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
          vecVals.push_back(packed);
        }

        // This needs to be ordered the same way that
        // ldmatrix.x4 would order it
        // TODO: this needs to be refactored so we don't
        // implicitly depend on how emitOffsetsForMMAV2
        // is implemented
        SmallVector<Value> reorderedVals;
        for (unsigned i = 0; i < vecVals.size(); i += 4) {
          reorderedVals.push_back(vecVals[i]);
          reorderedVals.push_back(vecVals[i + 2]);
          reorderedVals.push_back(vecVals[i + 1]);
          reorderedVals.push_back(vecVals[i + 3]);
        }

        // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);

        Type structTy =
            LLVM::LLVMStructType::getLiteral(this->getContext(), types);
        Value view =
            getStructFromElements(loc, reorderedVals, rewriter, structTy);
        rewriter.replaceOp(op, view);
        return success();
      }
      return lowerMmaToDotOperand(op, adaptor, rewriter);
    }
    // TODO: to be implemented
    llvm_unreachable("unsupported layout conversion");
    return failure();
  }

  static bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                                 DotOperandEncodingAttr &dotOperandLayout) {
    // dot_op<opIdx=0, parent=#mma> = #mma
    // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
    return mmaLayout.getWarpsPerCTA()[1] == 1 &&
           dotOperandLayout.getOpIdx() == 0 &&
           dotOperandLayout.getParent() == mmaLayout;
  }
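The isMmaToDotShortcut predicate is the core of this change: the accumulator a warp holds after an MMA-v2 dot is intended to match the layout operand A of the next dot expects, provided warps are not split along the N dimension (warpsPerCTA[1] == 1), the conversion targets operand A (opIdx 0), and the dot operand's parent is that same MMA layout. A minimal editorial sketch of the predicate on plain Python objects (not part of the diff; the dataclasses are hypothetical stand-ins for the MLIR attributes):

# Editorial sketch (not part of the diff): a plain-Python model of
# isMmaToDotShortcut. The dataclasses stand in for MmaEncodingAttr and
# DotOperandEncodingAttr.
from dataclasses import dataclass


@dataclass(frozen=True)
class MmaLayout:
    version: int
    warps_per_cta: tuple  # (warps along M, warps along N)


@dataclass(frozen=True)
class DotOperandLayout:
    op_idx: int           # 0 = operand A, 1 = operand B
    parent: "MmaLayout"


def is_mma_to_dot_shortcut(mma: MmaLayout, dot_op: DotOperandLayout) -> bool:
    # dot_op<opIdx=0, parent=#mma> = #mma, when warpsPerCTA = [..., 1]
    return (mma.warps_per_cta[1] == 1
            and dot_op.op_idx == 0
            and dot_op.parent == mma)


mma = MmaLayout(version=2, warps_per_cta=(4, 1))
assert is_mma_to_dot_shortcut(mma, DotOperandLayout(op_idx=0, parent=mma))
assert not is_mma_to_dot_shortcut(mma, DotOperandLayout(op_idx=1, parent=mma))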
  static void storeBlockedToShared(Value src, Value llSrc,
                                   ArrayRef<Value> srcStrides,
                                   ArrayRef<Value> srcIndices, Value dst,
@@ -3003,6 +2964,11 @@ private:
  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const;

  // mma -> dot_operand
  LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
                                     OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const;

  // shared -> dot_operand if the result layout is mma
  Value lowerSharedToDotOperandMMA(
      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3209,6 +3175,58 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
  return success();
}

LogicalResult ConvertLayoutOpConversion::lowerMmaToDotOperand(
    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto srcTy = op.src().getType().cast<RankedTensorType>();
  auto dstTy = op.result().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding();
  auto dstLayout = dstTy.getEncoding();
  auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
  auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
  if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
    // get source values
    auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
    unsigned elems = getElemsPerThread(srcTy);
    Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
    // for the destination type, we need to pack values together
    // so they can be consumed by tensor core operations
    unsigned vecSize =
        std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
    Type vecTy = vec_ty(elemTy, vecSize);
    SmallVector<Type> types(elems / vecSize, vecTy);
    SmallVector<Value> vecVals;
    for (unsigned i = 0; i < elems; i += vecSize) {
      Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
      for (unsigned j = 0; j < vecSize; j++)
        packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
      vecVals.push_back(packed);
    }

    // This needs to be ordered the same way that
    // ldmatrix.x4 would order it
    // TODO: this needs to be refactored so we don't
    // implicitly depend on how emitOffsetsForMMAV2
    // is implemented
    SmallVector<Value> reorderedVals;
    for (unsigned i = 0; i < vecVals.size(); i += 4) {
      reorderedVals.push_back(vecVals[i]);
      reorderedVals.push_back(vecVals[i + 2]);
      reorderedVals.push_back(vecVals[i + 1]);
      reorderedVals.push_back(vecVals[i + 3]);
    }

    // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);

    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
    Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
  }
  return failure();
}
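The packing and reordering above is what makes the register-level shortcut work for both float16 and float32: a 32-bit lane register holds two f16 values or a single f32 value, and groups of four registers are swapped into the order that ldmatrix.x4 would produce. A minimal editorial sketch of that regrouping on plain Python lists (illustrative values, not part of the diff):

# Editorial sketch (not part of the diff): the per-thread pack-and-reorder
# step of lowerMmaToDotOperand, modelled with plain Python lists.
def pack_and_reorder(vals, elem_bits):
    vec_size = max(32 // elem_bits, 1)              # f32 -> 1, f16 -> 2
    # pack vec_size scalars into one 32-bit "register"
    packed = [vals[i:i + vec_size] for i in range(0, len(vals), vec_size)]
    # swap groups of four registers into ldmatrix.x4 order: 0, 2, 1, 3
    reordered = []
    for i in range(0, len(packed), 4):
        reordered += [packed[i], packed[i + 2], packed[i + 1], packed[i + 3]]
    return reordered


# eight f32 accumulator values held by one thread
print(pack_and_reorder(list(range(8)), elem_bits=32))
# [[0], [2], [1], [3], [4], [6], [5], [7]]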
struct InsertSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
@@ -4625,6 +4643,34 @@ class ConvertTritonGPUToLLVM
    : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {

private:
  void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) {
    // replace `mma -> dot_op` with `mma -> blocked -> dot_op`
    // unless certain conditions are met
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      auto srcMma =
          srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (srcMma && dstDotOp &&
          !ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::BlockedEncodingAttr::get(
                mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
                getOrder(srcMma), numWarps));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
        cvtOp.erase();
      }
    });
  }
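When the shortcut does not apply, the pass falls back to a two-step conversion through a blocked layout so the existing distributed-to-distributed lowering can handle it. An editorial sketch of that rewrite on plain Python data (the string encodings are illustrative, not MLIR types):

# Editorial sketch (not part of the diff): the rewrite performed by
# decomposeMmaToDotOperand, modelled with strings instead of MLIR encodings.
def decompose(src_encoding, dst_encoding, shortcut_ok):
    """mma -> dot_op becomes mma -> blocked -> dot_op unless the shortcut applies."""
    if src_encoding == "mma" and dst_encoding == "dot_op" and not shortcut_ok:
        # insert a temporary blocked-layout tensor between two conversions
        return [("convert_layout", "mma", "blocked"),
                ("convert_layout", "blocked", "dot_op")]
    return [("convert_layout", src_encoding, dst_encoding)]


print(decompose("mma", "dot_op", shortcut_ok=False))  # two conversions
print(decompose("mma", "dot_op", shortcut_ok=True))   # single conversion kept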
  void decomposeBlockedToDotOperand(ModuleOp mod) {
    // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
    // because the codegen doesn't handle `blocked -> dot_op` directly
@@ -4771,6 +4817,8 @@ public:
    // separation between 1/4 is that, step 3 is out of the scope of Dialect
    // Conversion, thus we need to make sure the smem is not revised during the
    // conversion of step 4.
    decomposeMmaToDotOperand(mod, numWarps);

    decomposeBlockedToDotOperand(mod);

    decomposeInsertSliceAsyncOp(mod);
@@ -1071,21 +1071,23 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# # ---------------


@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
                         [(epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, epilogue, allow_tf32, dtype",
                         [(*shape, epilogue, allow_tf32, dtype)
                          for shape in [(64, 64, 64)]
                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                          for allow_tf32 in [True, False]
                          for dtype in ['float16']
                          for dtype in ['float16', 'float32']
                          if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 80:
    if capability[0] < 8:
        if dtype == 'int8':
            pytest.skip("Only test int8 on devices with sm >= 80")
        elif dtype == 'float32' and allow_tf32:
            pytest.skip("Only test tf32 on devices with sm >= 80")

    M, N, K = 64, 64, 64
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32

    num_warps = 4
    trans_a, trans_b = False, False
@@ -1130,7 +1132,8 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    if CHAIN_DOT:
        # tl.store(Zs, z)
        # tl.debug_barrier()
        z = tl.dot(z.to(tl.float16), tl.load(Ws))
        w = tl.load(Ws)
        z = tl.dot(z.to(w.dtype), w)
        tl.store(Zs, z)
    # input
    rs = RandomState(17)
@@ -1180,14 +1183,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
        z_ref = np.matmul(z_ref, w)
    # compare
    # print(z_ref[:,0], z_tri[:,0])
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    if dtype == 'float32':
        # XXX: Somehow there's a larger difference when we use float32
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32':
    elif dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
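The updated test drives exactly the pattern this backend change enables: the float32 accumulator of one tl.dot is fed straight back in as operand A of a second tl.dot. A self-contained editorial sketch of that chain-dot pattern (not part of the diff; names and block sizes are illustrative, and a CUDA-capable Triton install is assumed):

# Editorial sketch (not part of the diff): a standalone chain-dot kernel in
# the style of the updated test. Shapes are compile-time constants.
import torch
import triton
import triton.language as tl


@triton.jit
def chain_dot_kernel(X, W, V, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    x = tl.load(X + offs_m[:, None] * K + offs_k[None, :])   # M x K
    w = tl.load(W + offs_k[:, None] * N + offs_n[None, :])   # K x N
    v = tl.load(V + offs_n[:, None] * N + offs_n[None, :])   # N x N
    z = tl.dot(x, w)              # first dot; accumulator stays in registers
    z = tl.dot(z.to(v.dtype), v)  # chained dot: accumulator reused as operand A
    tl.store(Z + offs_m[:, None] * N + offs_n[None, :], z)


M, N, K = 64, 64, 64
x = torch.randn((M, K), dtype=torch.float32, device='cuda')
w = torch.randn((K, N), dtype=torch.float32, device='cuda')
v = torch.randn((N, N), dtype=torch.float32, device='cuda')
z = torch.empty((M, N), dtype=torch.float32, device='cuda')
chain_dot_kernel[(1,)](x, w, v, z, M=M, N=N, K=K)
torch.testing.assert_close(z, (x @ w) @ v, rtol=0.01, atol=1e-3)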