finish coding

Superjomn
2022-11-04 16:54:05 +08:00
parent 1f552308c4
commit 1ed6ee34ba
3 changed files with 135 additions and 44 deletions


@@ -3654,18 +3654,6 @@ public:
     return operand.getElementType().isF32();
   }
-  SmallVector<unsigned> getOrder() const {
-    SmallVector<unsigned> order(2);
-    if (mmaLayout.getVersion() == 1)
-      order = {0, 1};
-    else if (mmaLayout.getVersion() == 0)
-      order = {1, 0};
-    else {
-      assert(false && "Unexpected MMA version found.");
-    }
-    return order;
-  }
   Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
               Value smem, ConversionPatternRewriter &rewriter) const {
@@ -3674,7 +3662,6 @@ public:
     auto aShape = tensorTy.getShape();
     auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
     auto aOrder = aLayout.getOrder();
-    auto order = getOrder();
     bool isARow = aOrder[0] == 1;
@@ -3688,14 +3675,115 @@ public:
     int aNumPtr = 8;
     int bNumPtr = 8;
     int aVec = 2;
     int NK = aShape[isARow ? 1 : 0];
-    return Value{};
+    Value _0 = i32_val(0);
+    Value _1 = i32_val(1);
+    Value mContig = _1;
+    Value nContig = _1;
+    Value offA0 = isARow ? _0 : mul(threadId, mContig);
+    Value offA1 = isARow ? mul(threadId, mContig) : _0;
+    SmallVector<Value> aOff(aNumPtr);
+    for (int i = 0; i < aNumPtr; ++i) {
+      aOff[i] =
+          add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+    }
+    Type f32PtrTy = ptr_ty(f32_ty);
+    SmallVector<Value> aPtrs(aNumPtr);
+    for (int i = 0; i < aNumPtr; ++i)
+      aPtrs[i] = gep(f32PtrTy, llTensor, aOff[i]);
+    ValueTable has;
+    auto aShapePerCTA = getShapePerCTA(aLayout);
+    auto sizePerThread = getSizePerThread(aLayout);
+    int M = isARow ? aShape[0] : aShape[1];
+    int K = isARow ? aShape[1] : aShape[0];
+    for (unsigned k = 0; k < K; k++)
+      for (unsigned m = 0; m < M; m += aShapePerCTA[aOrder[1]])
+        for (unsigned mm = 0; mm < sizePerThread[aOrder[1]]; ++mm) {
+          Value pa = gep(f32PtrTy, aPtrs[0],
+                         i32_val((m + mm) * strideAM + k * strideAK));
+          Value va = load(pa);
+          has[{m + mm, k}] = va;
+        }
+    SmallVector<Value> values;
+    for (auto &item : has)
+      values.push_back(item.second);
+    Type structTy =
+        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
+    return getStructFromElements(loc, values, rewriter, structTy);
   }
   Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
               Value smem, ConversionPatternRewriter &rewriter) const {
-    return Value{};
+    auto *ctx = rewriter.getContext();
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto bShape = tensorTy.getShape();
+    auto bLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    auto bOrder = bLayout.getOrder();
+    bool isBRow = bOrder[0] == 1;
+    int strideBN = isBRow ? 1 : bShape[0];
+    int strideBK = isBRow ? bShape[1] : 1;
+    int strideB0 = isBRow ? strideBN : strideBK;
+    int strideB1 = isBRow ? strideBK : strideBN;
+    int ldb = isBRow ? strideBK : strideBN;
+    int bPerPhase = bLayout.getPerPhase();
+    int bMaxPhase = bLayout.getMaxPhase();
+    int bNumPtr = 8;
+    int bVec = 4;
+    auto bShapePerCTA = getShapePerCTA(bLayout);
+    auto sizePerThread = getSizePerThread(bLayout);
+    Value _0 = i32_val(0);
+    Value _1 = i32_val(1);
+    Value mContig = _1;
+    Value nContig = _1;
+    Value offB0 = isBRow ? mul(threadId, nContig) : _0;
+    Value offB1 = isBRow ? _0 : mul(threadId, nContig);
+    SmallVector<Value> bOff(bNumPtr);
+    for (int i = 0; i < bNumPtr; ++i) {
+      bOff[i] =
+          add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+    }
+    Type f32PtrTy = ptr_ty(f32_ty);
+    SmallVector<Value> bPtrs(bNumPtr);
+    for (int i = 0; i < bNumPtr; ++i)
+      bPtrs[i] = gep(f32PtrTy, llTensor, bOff[i]);
+    ValueTable hbs;
+    int K = isBRow ? bShape[0] : bShape[1];
+    int N = isBRow ? bShape[1] : bShape[0];
+    for (int k = 0; k < K; ++k)
+      for (unsigned n = 0; n < N; n += bShapePerCTA[bOrder[0]])
+        for (unsigned nn = 0; nn < sizePerThread[bOrder[0]]; ++nn) {
+          Value pb = gep(f32PtrTy, bPtrs[0],
+                         i32_val((n + nn) * strideBN + k * strideBK));
+          Value vb = load(pb);
+          hbs[{n + nn, k}] = vb;
+        }
+    SmallVector<Value> values;
+    for (auto &item : hbs)
+      values.push_back(item.second);
+    Type structTy =
+        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
+    return getStructFromElements(loc, values, rewriter, structTy);
   }
   ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
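
A note on the two bodies just added: loadA and loadB follow one pattern. Compute a per-thread base offset from threadId, step over the tile in (index, k) order using strides derived from the shared layout, load one scalar per position into an (index, k)-keyed ValueTable, then flatten the table into a single LLVM struct value. The sketch below restates only the addressing arithmetic as plain host-side C++; gatherA and its std::map are hypothetical stand-ins for the IR-building helpers (ValueTable, gep, load), not part of the commit.

    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    // Host-side model of loadA's gather: `a` is an M x K tile stored with
    // the strides the diff computes (row-major when isARow is true).
    std::map<std::pair<int, int>, float> gatherA(const std::vector<float> &a,
                                                 int M, int K, bool isARow) {
      int strideAM = isARow ? K : 1; // mirrors: isARow ? aShape[1] : 1
      int strideAK = isARow ? 1 : M; // mirrors: isARow ? 1 : aShape[0]
      std::map<std::pair<int, int>, float> has;
      for (int k = 0; k < K; ++k)
        for (int m = 0; m < M; ++m) // per-CTA/per-thread stepping elided
          has[{m, k}] = a[m * strideAM + k * strideAK];
      return has;
    }

    int main() {
      std::vector<float> a = {0, 1, 2, 3, 4, 5}; // 2 x 3, row-major
      auto has = gatherA(a, 2, 3, /*isARow=*/true);
      std::printf("a[1][2] = %g\n", has[{1, 2}]); // prints 5
    }

Iteration order of the map is what fixes the order of the elements packed into the returned struct.
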
@@ -3738,18 +3826,15 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
                                    rewriter, getTypeConverter(),
                                    op.getLoc());
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       res = mmaHelper.loadA(src, adaptor.src());
-    } else if (dotOperandLayout.getOpIdx() == 1) {
-      // operand $b
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
       res = mmaHelper.loadB(src, adaptor.src());
     }
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
                          adaptor.src(), loc, rewriter);
     } else if (dotOperandLayout.getOpIdx() == 1) {
@@ -3758,7 +3843,14 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
                          adaptor.src(), loc, rewriter);
     }
+  } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
+    DotOpMmaV1ConversionHelper helper(mmaLayout);
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
+      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
+                         adaptor.src(), loc, rewriter);
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
+      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
+                         adaptor.src(), loc, rewriter);
+    }
   } else {
     assert(false && "Unsupported mma layout found");
   }
@@ -4245,26 +4337,20 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
   auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto cLayout = cTensorTy.getEncoding().cast<MmaEncodingAttr>();
-  auto dLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
   auto aOrder = aLayout.getOrder();
   auto bOrder = bLayout.getOrder();
-  // According to the original logic, if target.sm < 80, get a {0,1} or get a
-  // {1,0}
-  SmallVector<int> order(2);
-  if (dLayout.getVersion() == 1)
-    order = {0, 1};
-  else
-    order = {1, 0};
+  auto order = dLayout.getOrder();
   bool isARow = aOrder[0] == 1;
   bool isBRow = bOrder[0] == 1;
   int strideAM = isARow ? aShape[1] : 1;
   int strideAK = isARow ? 1 : aShape[0];
-  int strideBN = isBRow ? 1 : aShape[0];
+  int strideBN = isBRow ? 1 : bShape[0];
   int strideBK = isBRow ? bShape[1] : 1;
   int strideA0 = isARow ? strideAK : strideAM;
   int strideA1 = isARow ? strideAM : strideAK;
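
One of the changes above is a one-character bug fix: strideBN for a column-major B must come from bShape[0], B's own leading dimension, not aShape[0]; the two only coincide when A and B happen to have the same leading dimension. A quick check, using the 32x16 and 16x32 shapes from the test file further down:

    #include <cassert>

    int main() {
      int aShape[] = {32, 16}, bShape[] = {16, 32};
      bool isBRow = false;                      // column-major B
      int strideBNold = isBRow ? 1 : aShape[0]; // old code: 32, wrong stride
      int strideBNnew = isBRow ? 1 : bShape[0]; // fixed:    16, B's own rows
      assert(strideBNold == 32 && strideBNnew == 16);
      return 0;
    }
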
@@ -4315,9 +4401,9 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
     bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
   // TODO initialize ret with $c.
-  std::map<std::pair<int, int>, Value> has, hbs;
+  DotOpFMAConversionHelper::ValueTable has, hbs;
   auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
-  SmallVector<Value> ret(cShape[0] * cShape[1], cc[0]);
+  SmallVector<Value> ret = cc;
   for (unsigned k = 0; k < NK; k++) {
     int z = 0;
@@ -4982,8 +5068,8 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  // Set array size 0 and external linkage indicates that we use dynamic shared
-  // allocation to allow a larger shared memory size for each kernel.
+  // Set array size 0 and external linkage indicates that we use dynamic
+  // shared allocation to allow a larger shared memory size for each kernel.
   auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
       loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
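
The rewrapped comment above describes the same mechanism CUDA exposes as dynamic shared memory: declare a zero-sized, externally linked byte array, and let each launch supply the actual size. A minimal CUDA sketch of that idea (the kernel, names, and sizes here are illustrative assumptions, not Triton code):

    #include <cstdio>

    __global__ void fill(float *out) {
      extern __shared__ char smem[]; // size set at launch, not compile time
      float *buf = reinterpret_cast<float *>(smem);
      buf[threadIdx.x] = static_cast<float>(threadIdx.x);
      __syncthreads();
      out[threadIdx.x] = buf[threadIdx.x];
    }

    int main() {
      float *out = nullptr;
      cudaMalloc(&out, 32 * sizeof(float));
      // Third launch parameter = dynamic shared memory size in bytes.
      fill<<<1, 32, 32 * sizeof(float)>>>(out);
      cudaDeviceSynchronize();
      cudaFree(out);
      return 0;
    }
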


@@ -117,10 +117,12 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
              "BlockedEncodingAttr not implemented");
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    assert(mmaLayout.getVersion() == 2 &&
-           "mmaLayout version = 1 is not implemented yet");
+    if (mmaLayout.getVersion() == 2)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              8 * mmaLayout.getWarpsPerCTA()[1]};
+    if (mmaLayout.getVersion() == 1)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              16 * mmaLayout.getWarpsPerCTA()[1]};
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
   }
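
To make the new branch concrete: for a hypothetical warpsPerCTA of {2, 2}, version 2 yields a per-CTA shape of {32, 16} while version 1 yields {32, 32}, so v1 covers twice as many columns per CTA. A small check mirroring the logic above:

    #include <array>
    #include <cassert>

    // Mirrors getShapePerCTA's two mma branches; {2, 2} is an assumed example.
    std::array<unsigned, 2> shapePerCTA(int version, std::array<unsigned, 2> w) {
      if (version == 2)
        return {16 * w[0], 8 * w[1]};
      return {16 * w[0], 16 * w[1]}; // version == 1
    }

    int main() {
      assert((shapePerCTA(2, {2, 2}) == std::array<unsigned, 2>{32, 16}));
      assert((shapePerCTA(1, {2, 2}) == std::array<unsigned, 2>{32, 32}));
      return 0;
    }
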


@@ -797,11 +797,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+  func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
   %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
-    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
     // CHECK: llvm.intr.fmuladd
-    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<16x16xf32, #mma>
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %28 : tensor<32x32xf32, #blocked>
     return
   }
 }