Compare commits: phil/mma-v ... keren/v100 (11 commits)

Commits (SHA1): 65896aef9d, d8d6b9f3f1, 3a1c140385, 0cfe909df8, e5cfa0f633, e552219104, 52accd4c2b, 4fb048873a, be2f70699c, 24fd953f9a, 793012b4c4
.github/workflows/integration-tests.yml (vendored, 9 changed lines)
@@ -89,7 +89,14 @@ jobs:
        if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
        run: |
          cd python/tests
          pytest test_gemm.py::test_gemm_for_mmav1
          pytest -k "not test_where_broadcast and not test_dot" test_core.py
          pytest test_gemm.py
          pytest test_backend.py
          pytest test_reduce.py
          pytest test_vecadd.py
          pytest test_elementwise.py
          pytest test_ext_elemwise.py
          pytest test_transpose.py

      - name: Run CXX unittests
        run: |
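The same reduced V100 test set can be reproduced locally; the commands below are quoted from the workflow step above, not new additions:

  cd python/tests
  pytest test_gemm.py::test_gemm_for_mmav1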
@@ -43,6 +43,8 @@ bool maybeSharedAllocationOp(Operation *op);

bool maybeAliasOp(Operation *op);

bool supportMMA(triton::DotOp op, int version);

std::string getValueOperandName(Value value, AsmState &state);

template <typename T_OUT, typename T_IN>
@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
    tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
    a's opIdx is 0, b's opIdx is 1.
    The parent field in DotOperandEncodingAttr is the layout of d.

    For MMA v1, an additional attribute `isMMAv1Row` determines whether, e.g., the a operand is used
    in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation,
    section 9.7.13.4.1, for more details.
  }];

  let parameters = (
    ins
    "unsigned":$opIdx,
    "Attribute":$parent
    "Attribute":$parent,
    "Attribute":$isMMAv1Row
  );

  let builders = [
    AttrBuilder<(ins "unsigned":$opIdx,
                     "Attribute":$parent), [{
      Attribute isMMAv1Row;
      if (parent.isa<MmaEncodingAttr>() &&
          parent.cast<MmaEncodingAttr>().getVersion() == 1) {
        isMMAv1Row = BoolAttr::get(context, true);
      }
      return $_get(context, opIdx, parent, isMMAv1Row);
    }]>
  ];

  let extraClassDeclaration = extraBaseClassDeclaration;
}

#endif
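A minimal sketch of how the attribute is meant to be built (assuming an MLIRContext `ctx` and an MMA-v1 layout attribute `mmaLayout` are in scope; both names are illustrative): the two-argument builder above defaults `isMMAv1Row` to true for MMA-v1 parents, while the full form lets callers pick the row/col flavor explicitly, as BlockedToMMA does further down.

  // Builder form: isMMAv1Row defaults to BoolAttr(true) because the parent is MMA v1.
  auto aEnc = triton::gpu::DotOperandEncodingAttr::get(ctx, /*opIdx=*/0, mmaLayout);
  // Full form: pass the row/col flag explicitly (here: a col-major operand).
  auto bEnc = triton::gpu::DotOperandEncodingAttr::get(
      ctx, /*opIdx=*/1, mmaLayout, BoolAttr::get(ctx, false));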
@@ -110,6 +110,19 @@ bool maybeAliasOp(Operation *op) {
         isa<tensor::InsertSliceOp>(op);
}

bool supportMMA(triton::DotOp op, int version) {
  // Refer to mma section for the data type supported by Volta and Hopper
  // Tensor Core in
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
  auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
  auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
  return (aElemTy.isF16() && bElemTy.isF16()) ||
         (aElemTy.isBF16() && bElemTy.isBF16()) ||
         (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
          version >= 2) ||
         (aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2);
}

std::string getValueOperandName(Value value, AsmState &state) {
  std::string opName;
  llvm::raw_string_ostream ss(opName);
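This helper centralizes the HMMA eligibility check that was previously duplicated in the LLVM lowering. A minimal usage sketch, mirroring the isDotHMMA and BlockedToMMA call sites further down (a `dotOp` and a `computeCapability` are assumed to be in scope):

  int version = computeCapabilityToMMAVersion(computeCapability);
  // Element types that cannot use Tensor Cores at this MMA version fall back
  // to the FMA (blocked-layout) dot lowering.
  if (!supportMMA(dotOp, version))
    return failure();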
@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
  }

  // Get the number of fp16x2 elements for $a.
  // \param shapeTransed: the shape or reordered shape if transpose needed.
  // \param shapeTransed: A's shape, or the reordered shape if a transpose is needed.
  // \param orderTransed: the order or reordered order if transpose needed.
  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
                   ArrayRef<unsigned> orderTransed) const {
    bool isARow = orderTransed[0] != 0;
  unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
    AParam param(isARow);

    unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
  }

  // Get the number of fp16x2 elements for $b.
  // \param shapeTransed: the shape or reordered shape if transpose needed.
  // \param shapeTransed: B's shape, or the reordered shape if a transpose is needed.
  // \param orderTransed: the order or reordered order if transpose needed.
  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
                   ArrayRef<unsigned> orderTransed) const {
    bool isBRow = orderTransed[0] != 0;
  unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
    BParam param(isBRow);

    unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
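For intuition only, a hedged worked example, assuming AParam reproduces the fpw/rep/spw arithmetic visible in convertMMA884 further down (rep[0] = 2 and spw[0] = 16 for a row-major, vec4-packed $a), and taking the 256x128 Volta kernel in python/tests/matmul.ttgir with warpsPerCTA = [4, 2] (`helper` here is an assumed DotOpMmaV1ConversionHelper instance, as in loadA below):

  // numM = rep[0] * shapeTransed[0] / (spw[0] * wpt[0]) = 2 * 256 / (16 * 4) = 8
  unsigned numM = helper.getNumM(/*shapeTransed=*/{256, 16}, /*isARow=*/true);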
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {

  int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
                         ArrayRef<unsigned> orderTransed) const {
    int numM = getNumM(shapeTransed, orderTransed);
    int numM = getNumM(shapeTransed, orderTransed[0] == 1);
    int NK = shapeTransed[1];

    // NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {

  int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
                         ArrayRef<unsigned> orderTransed) const {
    unsigned numN = getNumN(shapeTransed, orderTransed);
    unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
    int NK = shapeTransed[0];
    // NOTE: We couldn't get the vec from the shared layout.
    // int vecB = sharedLayout.getVec();
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
    }
  };

  unsigned numM = getNumM(shape, order);
  unsigned numM = getNumM(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
    }
  };

  unsigned numN = getNumN(shape, order);
  unsigned numN = getNumN(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned n = 0; n < numN / 2; ++n) {
      if (!hbs.count({n, k}))
@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
|
||||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
@@ -981,7 +980,7 @@ struct LoadOpConversion
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
Value llWord = undef(wordTy);
|
||||
// Insert each value element to the composition
|
||||
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
|
||||
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
|
||||
elem = bitcast(elem, valueElemTy);
|
||||
|
||||
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
|
||||
llWord =
|
||||
insert_element(wordTy, llWord, elem,
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
|
||||
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
||||
}
|
||||
llWord = bitcast(llWord, valArgTy);
|
||||
std::string constraint =
|
||||
@@ -2886,11 +2882,15 @@ private:
|
||||
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
||||
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
||||
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], _16);
|
||||
mmaRowIdx[0] =
|
||||
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
||||
laneIdRem2);
|
||||
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
|
||||
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
||||
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
|
||||
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
||||
@@ -3336,9 +3336,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
"Unsupported MMA kind found when converting DotOp to LLVM.");
|
||||
}
|
||||
|
||||
if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
|
||||
A.getType().cast<RankedTensorType>().getElementType().isF32() &&
|
||||
!op.allowTF32())
|
||||
// XXX: fp64 has not been tested yet. In theory, it should work.
|
||||
if (!isMMA)
|
||||
return convertFMADot(op, adaptor, rewriter);
|
||||
|
||||
llvm::report_fatal_error(
|
||||
@@ -3348,33 +3347,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
// Tell whether a DotOp supports HMMA.
// This is ported from the master branch; the original logic is retained.
|
||||
static bool isDotHMMA(DotOp op) {
|
||||
auto a = op.a();
|
||||
auto b = op.b();
|
||||
auto c = op.c();
|
||||
auto d = op.getResult();
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = c.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
|
||||
return false;
|
||||
|
||||
auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
auto aElemTy = aTensorTy.getElementType();
|
||||
auto bElemTy = bTensorTy.getElementType();
|
||||
|
||||
assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
|
||||
"Unexpected MMA layout version found");
|
||||
// Refer to mma section for the data type supported by Volta and Hopper
|
||||
// Tensor Core in
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
return (aElemTy.isF16() && bElemTy.isF16()) ||
|
||||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
|
||||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
|
||||
mmaLayout.getVersion() >= 2) ||
|
||||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
|
||||
mmaLayout.getVersion() >= 2);
|
||||
return supportMMA(op, mmaLayout.getVersion());
|
||||
}
|
||||
|
||||
// Tell whether a DotOp supports HMMA by the operand type (either $a or $b).
|
||||
@@ -3428,6 +3410,20 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
bool isMMAv1Row =
|
||||
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto srcSharedLayout = src.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<SharedEncodingAttr>();
|
||||
|
||||
// Can only convert [1, 0] to row or [0, 1] to col for now
|
||||
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
|
||||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
|
||||
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
|
||||
return Value();
|
||||
}
|
||||
|
||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||
// TODO[Superjomn]: transA is not available here.
|
||||
bool transA = false;
|
||||
@@ -3540,51 +3536,39 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
auto ALayout = A.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
auto BLayout = B.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
||||
SmallVector<int> AShape(ATensorTy.getShape().begin(),
|
||||
ATensorTy.getShape().end());
|
||||
SmallVector<int> BShape(BTensorTy.getShape().begin(),
|
||||
BTensorTy.getShape().end());
|
||||
auto AShape = ATensorTy.getShape();
|
||||
auto BShape = BTensorTy.getShape();
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
// TODO[Superjomn]: order cannot be accessed in DotOp.
|
||||
SmallVector<unsigned> AOrder({1, 0});
|
||||
SmallVector<unsigned> BOrder({1, 0});
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
|
||||
bool isARow = AOrder[0] != 0;
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
isAVec4 = true;
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
|
||||
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
|
||||
|
||||
Value loadedA = adaptor.a();
|
||||
Value loadedB = adaptor.b();
|
||||
Value loadedC = adaptor.c();
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
|
||||
unsigned numM = helper.getNumM(AShape, isARow);
|
||||
unsigned numN = helper.getNumN(BShape, isBRow);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
|
||||
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
|
||||
|
||||
// Initialize accumulators with external values. The acc holds the accumulator
// values shared between the MMA instructions inside a DotOp; we call the
// order of these values the accumulator-internal order.
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
@@ -3697,38 +3681,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto cShape = cTensorTy.getShape();
|
||||
|
||||
ValueTable has, hbs;
|
||||
int mShapePerCTA{-1}, nShapePerCTA{-1};
|
||||
int mSizePerThread{-1}, nSizePerThread{-1};
|
||||
ArrayRef<unsigned> aOrder, bOrder;
|
||||
Value llA, llB;
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto order = dLayout.getOrder();
|
||||
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
|
||||
DotOpFMAConversionHelper helper(dLayout);
|
||||
if (auto aDotOpLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
|
||||
// convert_layout
|
||||
auto bDotOpLayout =
|
||||
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
|
||||
assert(bLayout);
|
||||
llA = adaptor.a();
|
||||
llB = adaptor.b();
|
||||
} else if (auto aLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
|
||||
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(bLayout);
|
||||
Value thread = getThreadId(rewriter, loc);
|
||||
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
|
||||
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
|
||||
}
|
||||
Value llA = adaptor.a();
|
||||
Value llB = adaptor.b();
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
@@ -3737,17 +3702,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
int M = aShape[0];
|
||||
int N = bShape[1];
|
||||
|
||||
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
mSizePerThread =
|
||||
int mShapePerCTA =
|
||||
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
nSizePerThread =
|
||||
int nShapePerCTA =
|
||||
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
|
||||
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
|
||||
rewriter, loc);
|
||||
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
|
||||
rewriter, loc);
|
||||
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
|
||||
mSizePerThread, rewriter, loc);
|
||||
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
|
||||
nSizePerThread, rewriter, loc);
|
||||
|
||||
SmallVector<Value> ret = cc;
|
||||
for (unsigned k = 0; k < K; k++) {
|
||||
@@ -3758,7 +3725,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
|
||||
++z;
|
||||
}
|
||||
}
|
||||
@@ -4288,9 +4254,10 @@ struct ExpOpConversionApprox
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
return {};
|
||||
|
||||
const double log2e = 1.4426950408889634;
|
||||
Value prod =
|
||||
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
|
||||
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
@@ -4779,10 +4746,15 @@ private:
|
||||
decomposed = true;
|
||||
});
|
||||
|
||||
// async wait is supported in Ampere and later
|
||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
|
||||
decomposed) {
|
||||
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
|
||||
// async wait is supported in Ampere and later
|
||||
asyncWaitOp.erase();
|
||||
} else if (decomposed) {
|
||||
// Wait for all previous async ops
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
auto newAsyncWaitOp =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
@@ -4827,6 +4799,8 @@ public:
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
mod.print(llvm::errs());
|
||||
|
||||
RewritePatternSet scf_patterns(context);
|
||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||
mlir::ConversionTarget scf_target(*context);
|
||||
|
@@ -31,6 +31,7 @@
|
||||
#include <numeric>
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
|
||||
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
@@ -40,6 +41,7 @@
|
||||
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
|
||||
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
|
||||
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
|
||||
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
|
||||
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
|
||||
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
|
||||
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
|
||||
@@ -90,6 +92,8 @@
|
||||
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define ui32_ty rewriter.getIntegerType(32, false)
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
@@ -102,8 +106,9 @@
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
|
||||
|
||||
// Creator for constant
|
||||
// Constants
|
||||
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
|
||||
#define int_val(width, val) \
|
||||
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
|
||||
|
@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
// For now, this behaves like generic, but this will evolve when
|
||||
// we add support for `can_reorder=False`
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
|
||||
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
|
||||
adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
||||
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>,
|
||||
TritonCatPattern,
|
||||
TritonReducePattern,
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
|
||||
typeConverter, context);
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
||||
TritonAtomicRMWPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
|
||||
Attribute parent = attrs.get("parent");
|
||||
|
||||
Attribute isMMAv1Row;
|
||||
if(parent.isa<MmaEncodingAttr>() &&
|
||||
parent.cast<MmaEncodingAttr>().getVersion() == 1){
|
||||
isMMAv1Row = attrs.get("isMMAv1Row");
|
||||
if(!isMMAv1Row)
|
||||
llvm::report_fatal_error("isMMAv1Row attribute is missing");
|
||||
}
|
||||
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
|
||||
parent);
|
||||
parent, isMMAv1Row);
|
||||
}
|
||||
|
||||
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "opIdx = " << getOpIdx() << ", "
|
||||
<< "parent = " << getParent() << "}>";
|
||||
<< "parent = " << getParent();
|
||||
if(getIsMMAv1Row())
|
||||
printer << ", isMMAv1Row = " << getIsMMAv1Row();
|
||||
printer << "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
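With isMMAv1Row threaded through parse and print, the encoding round-trips in textual IR in the form used by python/tests/matmul.ttgir below, for example:

  #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>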
@@ -36,7 +36,7 @@ namespace {
|
||||
class DecomposeDotOperand : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
explicit DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
// IIUC they are therefore not handled by DRR right now
|
||||
class SimplifyConversion : public mlir::RewritePattern {
|
||||
public:
|
||||
SimplifyConversion(mlir::MLIRContext *context)
|
||||
explicit SimplifyConversion(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
4, context) {}
|
||||
|
||||
@@ -219,8 +219,8 @@ public:
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret) {
|
||||
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret) {
|
||||
ret = targetEncoding;
|
||||
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
ret = triton::gpu::SliceEncodingAttr::get(
|
||||
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
|
||||
if (isa<scf::YieldOp, scf::ForOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
BlockAndValueMapping &mapping) {
|
||||
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
// are reachable from it without passing through any memory operation.
|
||||
class RematerializeBackward : public mlir::RewritePattern {
|
||||
public:
|
||||
RematerializeBackward(mlir::MLIRContext *context)
|
||||
explicit RematerializeBackward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
@@ -303,7 +303,7 @@ public:
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
std::vector<std::pair<Operation *, Attribute>> queue;
|
||||
queue.push_back({cvt, targetType.getEncoding()});
|
||||
queue.emplace_back(cvt, targetType.getEncoding());
|
||||
int numCvts = 1;
|
||||
while (!queue.empty()) {
|
||||
Operation *currOp;
|
||||
@@ -341,7 +341,7 @@ public:
|
||||
continue;
|
||||
// we add one expensive conversion for the current operand
|
||||
numCvts += 1;
|
||||
queue.push_back({opArgI, newEncoding});
|
||||
queue.emplace_back(opArgI, newEncoding);
|
||||
}
|
||||
}
|
||||
// if rematerialization would add more conversions than it removes
|
||||
@@ -351,8 +351,8 @@ public:
|
||||
|
||||
SmallVector<Value, 4> sortedValues;
|
||||
SetVector<Operation *> tmp;
|
||||
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
|
||||
Value v = it->first;
|
||||
for (auto &item : toConvert) {
|
||||
Value v = item.first;
|
||||
if (v.getDefiningOp())
|
||||
tmp.insert(v.getDefiningOp());
|
||||
else
|
||||
@@ -393,7 +393,7 @@ public:
|
||||
|
||||
class MoveConvertOutOfLoop : public mlir::RewritePattern {
|
||||
public:
|
||||
MoveConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value, 4>
|
||||
@@ -406,7 +406,7 @@ public:
|
||||
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
|
||||
// Clone for loop
|
||||
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
|
||||
auto newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newInitArgs);
|
||||
newForOp->moveBefore(forOp);
|
||||
@@ -455,7 +455,7 @@ public:
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (auto iterArg : llvm::enumerate(iterArgs)) {
|
||||
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
|
||||
// if (iterArg.index() != 1)
|
||||
// continue;
|
||||
// skip non-tensor types
|
||||
@@ -517,7 +517,7 @@ public:
|
||||
|
||||
class RematerializeForward : public mlir::RewritePattern {
|
||||
public:
|
||||
RematerializeForward(mlir::MLIRContext *context)
|
||||
explicit RematerializeForward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
@@ -584,7 +584,7 @@ public:
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
namespace {
|
||||
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
if (computeCapability < 80) {
|
||||
return 1;
|
||||
} else if (computeCapability < 90) {
|
||||
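A hedged reading of the mapping, as far as the visible branches go (the return value of the `< 90` branch is elided in this hunk, but warpsPerTileV2 below suggests it is 2):

  // computeCapabilityToMMAVersion(70) == 1   // pre-Ampere, e.g. V100 (Volta)
  // computeCapabilityToMMAVersion(80) == 2   // Ampere-class parts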
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
}
|
||||
}
|
||||
|
||||
static SmallVector<int64_t, 2>
|
||||
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
int numWarps) {
|
||||
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
|
||||
if (version == 1)
|
||||
return {16, 16};
|
||||
else if (version == 2)
|
||||
@@ -608,22 +606,23 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(1, shape, numWarps);
|
||||
mmaVersionToShapePerWarp(1 /*version*/);
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
int pre = ret[0];
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
|
||||
changed = true;
|
||||
changed = pre != ret[0];
|
||||
}
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
pre = ret[1];
|
||||
ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
|
||||
changed = true;
|
||||
changed = pre != ret[1];
|
||||
}
|
||||
} while (changed);
|
||||
return ret;
|
||||
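The change to `changed` appears to guard against a non-terminating loop: previously `changed` was set to true even when clamping left `ret` unchanged, whereas now it only stays true while a dimension actually grows. As a worked example under that reading, tiling the 256x128 accumulator of the matmul.ttgir kernel over numWarps = 8 with the v1 shapePerWarp of {16, 16}:

  // ret evolves {1,1} -> {2,1} -> {2,2} -> {4,2}; 4*2 == 8 warps, matching
  // warpsPerCTA = [4, 2] in the #mma layout of python/tests/matmul.ttgir.
  SmallVector<unsigned, 2> wpt = warpsPerTileV1({256, 128}, /*numWarps=*/8);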
@@ -667,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
|
||||
class OptimizeBlockedToShared : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
@@ -713,6 +712,53 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if (auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if (auto srcSharedLayout =
|
||||
srcType.getEncoding()
|
||||
.dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return failure();
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return failure();
|
||||
if (!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return failure();
|
||||
bool isMMAv1Row =
|
||||
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
|
||||
return failure();
|
||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(),
|
||||
dstDotOperandLayout.getParent(), newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newDstType, cvt.getOperand());
|
||||
rewriter.replaceOp(op, newCvt.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -726,7 +772,7 @@ public:
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTileV1(dotOp, shape, numWarps);
|
||||
return warpsPerTileV1(shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTileV2(dotOp, shape, numWarps);
|
||||
default:
|
||||
@@ -744,18 +790,16 @@ public:
|
||||
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
|
||||
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||
|
||||
// for FMA, should retain the blocked layout.
|
||||
if (A.getElementType().isF32() && B.getElementType().isF32() &&
|
||||
!dotOp.allowTF32())
|
||||
if (!supportMMA(dotOp, version))
|
||||
return failure();
|
||||
|
||||
// get MMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||
|
||||
auto newRetType = RankedTensorType::get(
|
||||
retShape, oldRetType.getElementType(),
|
||||
@@ -770,18 +814,36 @@ public:
|
||||
Value b = dotOp.b();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
Attribute isMMAv1RowA;
|
||||
Attribute isMMAv1RowB;
|
||||
if (version == 1) {
|
||||
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
|
||||
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
|
||||
}
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding()));
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding()));
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
|
||||
auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
|
||||
b, newAcc, dotOp.allowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
@@ -789,6 +851,48 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class FixupLoop : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
explicit FixupLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
|
||||
// Rewrite init argument
|
||||
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
|
||||
bool shouldRematerialize = false;
|
||||
for (size_t i = 0; i < newInitArgs.size(); i++) {
|
||||
auto initArg = newInitArgs[i];
|
||||
auto regionArg = forOp.getRegionIterArgs()[i];
|
||||
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
|
||||
shouldRematerialize = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!shouldRematerialize)
|
||||
return failure();
|
||||
|
||||
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newInitArgs);
|
||||
newForOp->moveBefore(forOp);
|
||||
rewriter.setInsertionPointToStart(newForOp.getBody());
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
|
||||
for (Operation &op : forOp.getBody()->getOperations()) {
|
||||
Operation *newOp = rewriter.clone(op, mapping);
|
||||
}
|
||||
rewriter.replaceOp(forOp, newForOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
@@ -808,6 +912,7 @@ public:
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<OptimizeBlockedToShared>(context);
|
||||
patterns.add<OptimizeConvertToDotOperand>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
@@ -818,6 +923,13 @@ public:
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// llvm::outs() << m << "\n";
|
||||
mlir::RewritePatternSet loopFixup(context);
|
||||
loopFixup.add<FixupLoop>(context);
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
||||
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = builder.clone(op, mapping);
|
||||
|
@@ -120,6 +120,7 @@ void init_triton_ir(py::module &&m) {
|
||||
// some placeholders
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
self.getOrLoadDialect<mlir::gpu::GPUDialect>();
|
||||
});
|
||||
// .def(py::init([](){
|
||||
// mlir::MLIRContext context;
|
||||
@@ -1265,7 +1266,13 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<::mlir::LLVM::UndefOp>(loc, type);
|
||||
});
|
||||
})
|
||||
// Force GPU barrier
|
||||
.def("create_barrier",
|
||||
[](mlir::OpBuilder &self) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::gpu::BarrierOp>(loc);
|
||||
});
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
|
python/tests/matmul.ttgir (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 2]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
||||
func public @_kernel_0d1d2d3d4d5d6d7d8d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c0_i32 = arith.constant 0 : index
|
||||
%cst = arith.constant dense<32> : tensor<256x32xi32, #blocked0>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%c255_i32 = arith.constant 255 : i32
|
||||
%c127_i32 = arith.constant 127 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%c0 = arith.constant 0 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.get_program_id {axis = 1 : i32} : i32
|
||||
%2 = arith.addi %arg3, %c255_i32 : i32
|
||||
%3 = arith.divsi %2, %c256_i32 : i32
|
||||
%4 = arith.addi %arg4, %c127_i32 : i32
|
||||
%5 = arith.divsi %4, %c128_i32 : i32
|
||||
%6 = arith.muli %5, %c8_i32 : i32
|
||||
%7 = arith.divsi %0, %6 : i32
|
||||
%8 = arith.muli %7, %c8_i32 : i32
|
||||
%9 = arith.subi %3, %8 : i32
|
||||
%10 = arith.cmpi slt, %9, %c8_i32 : i32
|
||||
%11 = select %10, %9, %c8_i32 : i32
|
||||
%12 = arith.remsi %0, %11 : i32
|
||||
%13 = arith.addi %8, %12 : i32
|
||||
%14 = arith.remsi %0, %6 : i32
|
||||
%15 = arith.divsi %14, %11 : i32
|
||||
%16 = arith.muli %13, %c256_i32 : i32
|
||||
%17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%19 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%20 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%21 = arith.muli %15, %c128_i32 : i32
|
||||
%22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%23 = tt.splat %21 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%24 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%25 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%26 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%27 = arith.muli %1, %c32_i32 : i32
|
||||
%28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%30 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%31 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%32 = arith.addi %19, %17 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%33 = arith.remsi %32, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%34 = tt.splat %arg6 : (i32) -> tensor<256x1xi32, #blocked0>
|
||||
%35 = arith.addi %30, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x32xi32, #blocked0>
|
||||
%37 = tt.broadcast %36 : (tensor<1x32xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
|
||||
%38 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<256x32x!tt.ptr<f16>, #blocked0>
|
||||
%39 = arith.addi %31, %29 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
|
||||
%41 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
|
||||
%42 = arith.addi %23, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%43 = arith.remsi %42, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%44 = tt.expand_dims %43 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
|
||||
%45 = tt.broadcast %44 : (tensor<1x128xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
|
||||
%46 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #blocked1>
|
||||
%47 = arith.index_cast %arg5 : i32 to index
|
||||
%48 = arith.muli %arg7, %c32_i32 : i32
|
||||
%49 = tt.splat %48 : (i32) -> tensor<32x128xi32, #blocked1>
|
||||
%50 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<256x1xi32, #blocked0>
|
||||
%51 = arith.muli %50, %34 : tensor<256x1xi32, #blocked0>
|
||||
%52 = tt.broadcast %51 : (tensor<256x1xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
|
||||
%53 = arith.addi %52, %37 : tensor<256x32xi32, #blocked0>
|
||||
%54 = tt.addptr %38, %53 : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
|
||||
%55 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1>
|
||||
%56 = tt.broadcast %55 : (tensor<32x1xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
|
||||
%57 = arith.addi %56, %45 : tensor<32x128xi32, #blocked1>
|
||||
%58 = tt.addptr %46, %57 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
|
||||
%59 = arith.cmpi slt, %c0, %47 : index
|
||||
%60 = triton_gpu.alloc_tensor : tensor<2x256x32xf16, #shared0>
|
||||
%64 = triton_gpu.alloc_tensor : tensor<2x32x128xf16, #shared1>
|
||||
%61 = tt.splat %59 : (i1) -> tensor<256x32xi1, #blocked0>
|
||||
%65 = tt.splat %59 : (i1) -> tensor<32x128xi1, #blocked1>
|
||||
%62 = tt.load %54, %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
|
||||
%66 = tt.load %58, %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
|
||||
%63 = tensor.insert_slice %62 into %60[%c0_i32, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
|
||||
%67 = tensor.insert_slice %66 into %64[%c0_i32, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
|
||||
%68 = tt.addptr %54, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
|
||||
%69 = tt.addptr %58, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
|
||||
%70 = tensor.extract_slice %63[0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
|
||||
%71 = tensor.extract_slice %67[0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
|
||||
%72 = tensor.extract_slice %70[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
|
||||
gpu.barrier
|
||||
%73 = triton_gpu.convert_layout %72 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
|
||||
%74 = tensor.extract_slice %71[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
|
||||
%75 = triton_gpu.convert_layout %74 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
|
||||
%76:14 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %cst_0, %arg11 = %54, %arg12 = %58, %arg13 = %63, %arg14 = %67, %arg15 = %70, %arg16 = %71, %arg17 = %68, %arg18 = %69, %arg19 = %c0, %arg20 = %c1_i32, %arg21 = %c1_i32, %arg22 = %73, %arg23 = %75) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>) {
|
||||
%104 = arith.addi %arg19, %c32 : index
|
||||
%105 = arith.cmpi slt, %104, %47 : index
|
||||
%106 = arith.remsi %arg20, %c2_i32 : i32
|
||||
%107 = arith.remsi %arg21, %c2_i32 : i32
|
||||
%108 = arith.index_cast %107 : i32 to index
|
||||
%200 = arith.index_cast %106 : i32 to index
|
||||
%109 = tt.splat %105 : (i1) -> tensor<256x32xi1, #blocked0>
|
||||
%112 = tt.splat %105 : (i1) -> tensor<32x128xi1, #blocked1>
|
||||
%110 = tt.load %arg17, %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
|
||||
%113 = tt.load %arg18, %112 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
|
||||
%96 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
|
||||
%97 = tensor.extract_slice %arg15[0, 16] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
|
||||
%98 = triton_gpu.convert_layout %97 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
|
||||
%99 = tensor.extract_slice %arg16[16, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
|
||||
%100 = triton_gpu.convert_layout %99 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
|
||||
%101 = tt.dot %98, %100, %96 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
|
||||
%102 = tt.addptr %arg11, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
|
||||
%103 = tt.addptr %arg12, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
|
||||
gpu.barrier
|
||||
%111 = tensor.insert_slice %110 into %arg13[%200, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
|
||||
%114 = tensor.insert_slice %113 into %arg14[%200, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
|
||||
gpu.barrier
|
||||
%115 = tt.addptr %arg17, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
|
||||
%116 = tt.addptr %arg18, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
|
||||
%117 = tensor.extract_slice %111[%108, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
|
||||
%118 = tensor.extract_slice %114[%108, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
|
||||
%119 = arith.addi %arg20, %c1_i32 : i32
|
||||
%120 = arith.addi %arg21, %c1_i32 : i32
|
||||
%121 = tensor.extract_slice %117[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
|
||||
%122 = triton_gpu.convert_layout %121 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
|
||||
%123 = tensor.extract_slice %118[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
|
||||
%124 = triton_gpu.convert_layout %123 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
|
||||
scf.yield %101, %102, %103, %111, %114, %117, %118, %115, %116, %104, %119, %120, %122, %124 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
|
||||
}
|
||||
gpu.barrier
|
||||
%77 = triton_gpu.convert_layout %76#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked1>
|
||||
%78 = arith.addi %20, %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%79 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked1>
|
||||
%80 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
|
||||
%81 = tt.broadcast %80 : (tensor<1x128xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
|
||||
%82 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<256x128x!tt.ptr<f16>, #blocked1>
|
||||
%83 = "triton_gpu.cmpi"(%78, %25) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%84 = "triton_gpu.cmpi"(%42, %26) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%85 = tt.expand_dims %84 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi1, #blocked1>
|
||||
%86 = tt.broadcast %85 : (tensor<1x128xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
|
||||
%87 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi32, #blocked1>
|
||||
%88 = arith.muli %87, %79 : tensor<256x1xi32, #blocked1>
|
||||
%89 = tt.broadcast %88 : (tensor<256x1xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
|
||||
%90 = arith.addi %89, %81 : tensor<256x128xi32, #blocked1>
|
||||
%91 = tt.addptr %82, %90 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1>
|
||||
%92 = arith.truncf %77 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
|
||||
%93 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi1, #blocked1>
|
||||
%94 = tt.broadcast %93 : (tensor<256x1xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
|
||||
%95 = arith.andi %94, %86 : tensor<256x128xi1, #blocked1>
|
||||
tt.store %91, %92, %95 : tensor<256x128xf16, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
@@ -32,7 +32,7 @@ def matmul_no_scf_kernel(
|
||||
(shape, num_warps, trans_a, trans_b)
|
||||
for shape in [
|
||||
[128, 256, 32],
|
||||
[256, 128, 16],
|
||||
# [256, 128, 16],
|
||||
[128, 16, 32],
|
||||
[32, 128, 64],
|
||||
[128, 128, 64],
|
||||
@@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
for shape in [
|
||||
[64, 128, 128],
|
||||
[128, 128, 128],
|
||||
[16, 8, 32],
|
||||
[16, 16, 32],
|
||||
[32, 16, 64],
|
||||
[32, 16, 64],
|
||||
]
|
||||
@@ -81,6 +81,8 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
for trans_b in [False, True]
|
||||
])
|
||||
def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
guard_for_volta(is_int8=True)
|
||||
|
||||
SIZE_M, SIZE_N, SIZE_K = SHAPE
|
||||
|
||||
if (TRANS_A):
|
||||
@@ -195,6 +197,7 @@ def get_proper_err(a, b, golden):
|
||||
[128, 64, 128, 4, 128, 64, 32, False, True],
|
||||
])
|
||||
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||
|
||||
if (TRANS_A):
|
||||
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -270,6 +273,8 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, c_mask)
|
||||
|
||||
guard_for_volta(is_tf32=allow_tf32)
|
||||
|
||||
# Configure the pytorch counterpart
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
@@ -294,18 +299,16 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
    torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
    # Non-forloop
    [16, 16, 16, 1, 16, 16, 16, False, False],
    [16, 16, 32, 1, 16, 16, 32, False, False],
    [32, 16, 32, 1, 32, 16, 32, False, False],
    [32, 32, 32, 1, 32, 32, 32, False, False],
    [128, 32, 32, 1, 128, 32, 32, False, False],
    # split-K
    [16, 16, 32, 1, 16, 16, 16, False, False],
    [64, 64, 128, 1, 64, 64, 32, False, False],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
    test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

def guard_for_volta(is_int8=False, is_tf32=False):
    '''
    Tell whether the test case is valid on a Volta GPU.
    Some features are still WIP, so the corresponding support is missing.
    '''
    capability = torch.cuda.get_device_capability()
    is_on_Volta = capability[0] < 8
    # TODO[Superjomn]: Remove the constraints below when features are ready
    is_feature_supported = not (is_int8 or is_tf32)

    if is_on_Volta:
        if (not is_feature_supported):
            pytest.skip("Not valid on Volta")
|
101
python/tests/test_matmul.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if capability[0] < 8 and DTYPE == "bfloat16":
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
#if DTYPE == "bfloat16" and SPLIT_K != 1:
|
||||
# pytest.skip("bfloat16 matmuls don't allow split_k for now")
|
||||
if DTYPE == "bfloat16":
    pytest.skip("bfloat16 matmuls are not supported for now")
torch.manual_seed(0)
|
||||
# nuke kernel decorators -- will set meta-parameters manually
|
||||
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
|
||||
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
|
||||
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
|
||||
kernel = triton.ops._matmul.kernel
|
||||
kernel.configs = configs
|
||||
# kernel.run = kernel.run.run.run
|
||||
|
||||
# get matrix shape
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K * SPLIT_K if K is None else K
|
||||
# allocate/transpose inputs
|
||||
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
|
||||
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
|
||||
triton.testing.assert_almost_equal(th_c, tt_c)
|
164
python/tests/test_performance.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, set_gpu_clock
|
||||
|
||||
DEVICE_NAME = 'v100'
|
||||
|
||||
#######################
|
||||
# Utilities
|
||||
#######################
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
attrs = ','.join(attrs)
|
||||
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
|
||||
out = subprocess.check_output(cmd)
|
||||
ret = out.decode(sys.stdout.encoding).split(',')
|
||||
ret = [int(x) for x in ret]
|
||||
return ret
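For reference, the helper returns the queried attributes as integers in request order; a minimal usage sketch (assuming nvidia-smi is on PATH and GPU 0 exists):

sm_mhz, mem_mhz = nvsmi(['clocks.current.sm', 'clocks.current.memory'])
# e.g. roughly 1350 and 877 on the V100 configuration assumed below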
|
||||
|
||||
|
||||
#######################
|
||||
# Matrix Multiplication
|
||||
#######################
|
||||
|
||||
sm_clocks = {'v100': 1350, 'a100': 1350}
|
||||
mem_clocks = {'v100': 877, 'a100': 1215}
|
||||
|
||||
matmul_data = {
|
||||
'v100': {
|
||||
# square
|
||||
(256, 256, 256): {'float16': 0.027},
|
||||
(512, 512, 512): {'float16': 0.158},
|
||||
(1024, 1024, 1024): {'float16': 0.466},
|
||||
(2048, 2048, 2048): {'float16': 0.695},
|
||||
(4096, 4096, 4096): {'float16': 0.831},
|
||||
(8192, 8192, 8192): {'float16': 0.849},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'float16': 0.0128},
|
||||
(16, 4096, 4096): {'float16': 0.0883},
|
||||
(16, 8192, 8192): {'float16': 0.101},
|
||||
(64, 1024, 1024): {'float16': 0.073},
|
||||
(64, 4096, 4096): {'float16': 0.270},
|
||||
(64, 8192, 8192): {'float16': 0.459},
|
||||
(1024, 64, 1024): {'float16': 0.0692},
|
||||
(4096, 64, 4096): {'float16': 0.264},
|
||||
(8192, 64, 8192): {'float16': 0.452},
|
||||
},
|
||||
'a100': {
|
||||
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
|
||||
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
|
||||
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
|
||||
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
|
||||
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
|
||||
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
|
||||
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
|
||||
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
|
||||
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
|
||||
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
|
||||
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
|
||||
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
|
||||
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
|
||||
}
|
||||
# # deep reductions
|
||||
# (64 , 64 , 16384) : {'a100': 0.},
|
||||
# (64 , 64 , 65536) : {'a100': 0.},
|
||||
# (256 , 256 , 8192 ) : {'a100': 0.},
|
||||
# (256 , 256 , 32768) : {'a100': 0.},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str',
|
||||
[(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
def test_matmul(M, N, K, dtype_str):
|
||||
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
|
||||
pytest.skip('Only test float32 & int8 on a100')
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
|
||||
torch.manual_seed(0)
|
||||
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
|
||||
ref_sm_clock = sm_clocks[DEVICE_NAME]
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
|
||||
if dtype == torch.int8:
|
||||
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
|
||||
b = b.t() # only test row-col layout
|
||||
else:
|
||||
a = torch.randn((M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randn((K, N), dtype=dtype, device='cuda')
|
||||
fn = lambda: triton.ops.matmul(a, b)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
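For intuition, a sketch of the utilization arithmetic above; the 1.45 ms timing is made up for illustration:

M = N = K = 4096
ms = 1.45                                  # hypothetical do_bench result, in milliseconds
cur_gpu_perf = 2. * M * N * K / ms * 1e-9  # about 94.8 achieved TFLOPS
# dividing by max_gpu_perf (the clock-derived fp16 peak) yields the utilization
# compared against matmul_data[DEVICE_NAME][(M, N, K)]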
|
||||
|
||||
|
||||
#######################
|
||||
# Element-Wise
|
||||
#######################
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
elementwise_data = {
|
||||
'v100': {
|
||||
1024 * 16: 0.0219,
|
||||
1024 * 64: 0.0791,
|
||||
1024 * 256: 0.243,
|
||||
1024 * 1024: 0.530,
|
||||
1024 * 4096: 0.796,
|
||||
1024 * 16384: 0.905,
|
||||
1024 * 65536: 0.939,
|
||||
},
|
||||
'a100': {
|
||||
1024 * 16: 0.008,
|
||||
1024 * 64: 0.034,
|
||||
1024 * 256: 0.114,
|
||||
1024 * 1024: 0.315,
|
||||
1024 * 4096: 0.580,
|
||||
1024 * 16384: 0.782,
|
||||
1024 * 65536: 0.850,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
|
||||
def test_elementwise(N):
|
||||
torch.manual_seed(0)
|
||||
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
|
||||
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
|
||||
ref_mem_clock = mem_clocks[DEVICE_NAME]
|
||||
max_gpu_perf = get_dram_gbps()
|
||||
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
|
||||
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
|
||||
x = torch.randn_like(z)
|
||||
y = torch.randn_like(z)
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
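The element-wise check works the same way on bytes instead of flops; a sketch with a made-up timing:

N = 1024 * 65536
bytes_moved = 3. * N * 2                 # two fp16 loads plus one fp16 store per element
ms = 0.45                                # hypothetical do_bench result, in milliseconds
cur_gpu_perf = bytes_moved / ms * 1e-6   # about 895 GB/s
# divided by get_dram_gbps() this gives the utilization checked against elementwise_data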
|
@@ -68,7 +68,7 @@ def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
|
||||
@num_elements: number of elements
|
||||
'''
|
||||
pid = tl.program_id(axis=0)
|
||||
for i in range(math.ceil(block_size / iter_size)):
|
||||
for i in range(tl.cdiv(block_size, iter_size)):
|
||||
# TODO: there is a bug here; if the offset is computed outside the for loop, a GPU misaligned-address error occurs.
offset = pid * block_size + tl.arange(0, iter_size)
|
||||
x_ptrs = x_ptr + offset
|
||||
|
@@ -329,10 +329,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
@@ -591,8 +587,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
# handle negative constant step (not supported by scf.for in MLIR)
|
||||
negative_step = False
|
||||
if isinstance(step, triton.language.constexpr) and step.value < 0:
|
||||
step = triton.language.constexpr(-step.value)
|
||||
negative_step = True
|
||||
lb, ub = ub, lb
|
||||
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||
@@ -640,6 +638,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# update induction variable with actual value, and replace all uses
|
||||
self.builder.set_insertion_point_to_start(for_op.get_body(0))
|
||||
iv = self.builder.create_index_to_si(for_op.get_induction_var())
|
||||
if negative_step:
|
||||
ub_si = self.builder.create_index_to_si(ub)
|
||||
iv = self.builder.create_sub(ub_si, iv)
|
||||
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
|
||||
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
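A minimal Python sketch (not the MLIR builder API) of what the negative-step handling in the two hunks above does: negate the step, swap the bounds, run the forward loop, and recover the user-visible index as ub minus the forward induction variable:

def forward_equivalent(lb, ub, step):
    assert step < 0
    step = -step                     # scf.for only supports positive steps
    lb, ub = ub, lb
    for iv in range(lb, ub, step):   # the forward loop scf.for actually executes
        yield ub - iv                # re-derived induction value used inside the body

assert list(forward_equivalent(10, 0, -1)) == list(range(10, 0, -1))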
|
||||
|
||||
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.enable_debug()
|
||||
# Convert blocked layout to mma layout for dot ops so that pipeline
|
||||
# can get shared memory swizzled correctly.
|
||||
pm.add_coalesce_pass()
|
||||
# The combine pass converts blocked layout to mma layout
|
||||
# for dot ops so that pipeline can get shared memory swizzled correctly.
|
||||
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
# Prefetch must be done after pipeline pass because pipeline pass
|
||||
@@ -1358,12 +1359,12 @@ def make_hash(fn, **kwargs):
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
#                 and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
#            (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
#   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
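A quick illustration of the two capture groups, on a made-up prototype line rather than one from the repo:

import re
line = 'func public @kernel0(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32) {'
m = re.search(mlir_prototype_pattern, line, re.MULTILINE)
# m.group(1) -> '@kernel0'
# m.group(2) -> '(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32)'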
|
||||
@@ -1384,6 +1385,8 @@ arg_type_pattern = {
|
||||
|
||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
def compile(fn, **kwargs):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
# we get the kernel, i.e. the first function generated in the module
|
||||
# if fn is not a JITFunction, then it
|
||||
# has to be a path to a file
|
||||
@@ -1391,24 +1394,22 @@ def compile(fn, **kwargs):
|
||||
asm = dict()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
device = kwargs.get("device", torch.cuda.current_device())
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0]*10 + capability[1]
|
||||
# build compilation stages
|
||||
stages = {
|
||||
"ast" : (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
}
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
@@ -1430,9 +1431,7 @@ def compile(fn, **kwargs):
|
||||
import re
|
||||
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
|
||||
name, signature = match.group(1), match.group(2)
|
||||
print(name, signature)
|
||||
types = re.findall(arg_type_pattern[ir], signature)
|
||||
print(types)
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir)
|
||||
@@ -1467,8 +1466,8 @@ def compile(fn, **kwargs):
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and\
|
||||
ir in metadata["ctime"] and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
ir in metadata["ctime"] and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
@@ -1504,8 +1503,7 @@ class CompiledKernel:
|
||||
self.asm = asm
|
||||
device = torch.cuda.current_device()
|
||||
global cuda_utils
|
||||
if cuda_utils is None:
|
||||
cuda_utils = CudaUtils()
|
||||
init_cuda_utils()
|
||||
mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
@@ -1562,6 +1560,34 @@ class CudaUtils(object):
|
||||
|
||||
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
|
||||
|
||||
static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
|
||||
int device_id;
|
||||
if(!PyArg_ParseTuple(args, "i", &device_id))
|
||||
return NULL;
|
||||
// Get device handle
|
||||
CUdevice device;
|
||||
cuDeviceGet(&device, device_id);
|
||||
|
||||
// create a struct to hold device properties
|
||||
int max_shared_mem;
|
||||
int multiprocessor_count;
|
||||
int sm_clock_rate;
|
||||
int mem_clock_rate;
|
||||
int mem_bus_width;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
|
||||
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
|
||||
"multiprocessor_count", multiprocessor_count,
|
||||
"sm_clock_rate", sm_clock_rate,
|
||||
"mem_clock_rate", mem_clock_rate,
|
||||
"mem_bus_width", mem_bus_width);
|
||||
}
|
||||
|
||||
static PyObject* loadBinary(PyObject* self, PyObject* args) {
|
||||
const char* name;
|
||||
const char* data;
|
||||
@@ -1601,6 +1627,7 @@ class CudaUtils(object):
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
@@ -1640,6 +1667,13 @@ class CudaUtils(object):
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
self.load_binary = mod.load_binary
|
||||
self.get_device_properties = mod.get_device_properties
|
||||
|
||||
|
||||
def init_cuda_utils():
|
||||
global cuda_utils
|
||||
if cuda_utils is None:
|
||||
cuda_utils = CudaUtils()
|
||||
|
||||
|
||||
cuda_utils = None
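A short sketch of how the properties exposed by getDeviceProperties are consumed elsewhere in this diff (illustrative only; the keys are the ones returned by the C helper above):

import torch
import triton

triton.compiler.init_cuda_utils()
props = triton.compiler.cuda_utils.get_device_properties(torch.cuda.current_device())
num_subcores = props["multiprocessor_count"] * 4   # as used in get_tensorcore_tflops
max_shared = props["max_shared_mem"]               # as used in early_config_prune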
|
||||
|
@@ -9,6 +9,7 @@ from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def _to_tensor(x, builder):
|
||||
if isinstance(x, bool):
|
||||
return tensor(builder.get_int1(x), int1)
|
||||
@@ -348,6 +349,9 @@ class constexpr:
|
||||
def __mul__(self, other):
|
||||
return constexpr(self.value * other.value)
|
||||
|
||||
def __mod__(self, other):
|
||||
return constexpr(self.value % other.value)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return constexpr(other.value * self.value)
|
||||
|
||||
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
|
||||
"""
|
||||
return semantic.broadcast_impl_shape(input, shape, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def trans(input, _builder=None):
|
||||
return semantic.trans(input, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def cat(input, other, can_reorder=False, _builder=None):
|
||||
"""
|
||||
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
|
||||
shape = [x.value for x in shape]
|
||||
return semantic.view(input, shape, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
# TODO: should be more than just a view
|
||||
|
@@ -481,7 +481,8 @@ def view(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
# TODO: disable when TritonToTritonGPU handles views properly
|
||||
assert len(input.shape) == len(dst_shape)
|
||||
|
||||
# assert len(input.shape) == len(dst_shape)
|
||||
numel = 1
|
||||
for s in dst_shape:
|
||||
numel *= s
|
||||
@@ -982,6 +983,11 @@ def dot(lhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[1].value == rhs.shape[0].value
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16,\
|
||||
"small blocks not supported!"
|
||||
if lhs.type.scalar.is_int():
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
@@ -1138,7 +1144,7 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
|
||||
|
||||
def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_barrier(''), tl.void)
|
||||
return tl.tensor(builder.create_barrier(), tl.void)
|
||||
|
||||
|
||||
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
|
@@ -26,32 +26,30 @@ def get_configs_io_bound():
|
||||
return configs
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
#configs=[
|
||||
# # basic configs for compute-bound matmuls
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# # good for int8
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
#] + get_configs_io_bound(),
|
||||
configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=2, num_warps=8)],
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
@@ -59,6 +57,9 @@ def get_configs_io_bound():
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
@@ -113,7 +114,7 @@ def _kernel(A, B, C, M, N, K,
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
|
||||
kernel = _kernel
|
||||
kernel = None
|
||||
|
||||
_locks = dict()
|
||||
|
||||
@@ -134,12 +135,17 @@ class _matmul(torch.autograd.Function):
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
#grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
if _matmul.kernel is None:
|
||||
_matmul.kernel = triton.compile("/root/code/triton-mlir/python/tests/matmul.ttgir", num_stages=2, num_warps=8)
|
||||
#_matmul.kernel = _kernel
|
||||
_matmul.kernel[(8192//256 * 8192//128, 1, 1,)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
M, N, K,
|
||||
a.stride(0), b.stride(0), c.stride(0))
|
||||
#_matmul.kernel[grid](a, b, c,
|
||||
# M, N, K,
|
||||
# a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
|
||||
# GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
|
@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcor
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
triton.compiler.init_cuda_utils()
|
||||
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80 and dtype == torch.float32:
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8 and dtype == torch.float32:
|
||||
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
|
||||
@@ -59,7 +61,7 @@ def estimate_matmul_time(
|
||||
compute_ms = total_ops / tput
|
||||
|
||||
# time to load data
|
||||
num_sm = _triton.runtime.num_sm(backend, device)
|
||||
num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
|
||||
active_cta_ratio = min(1, num_ctas / num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
|
||||
@@ -99,7 +101,7 @@ def estimate_matmul_time(
|
||||
def early_config_prune(configs, named_args):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
dtsize = named_args['A'].element_size()
|
||||
dtype = named_args['A'].dtype
|
||||
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
|
||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||
|
||||
# TODO: move to `cuda_utils` submodule
|
||||
triton.compiler.init_cuda_utils()
|
||||
max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory <= max_shared_memory:
|
||||
pruned_configs.append(config)
|
||||
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if cc >= 80:
|
||||
if capability[0] >= 8:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
|
||||
mma_cycles = mmas / min(4, num_warps) * 8
|
||||
|
@@ -74,6 +74,8 @@ class Autotuner(KernelInterface):
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
for config, ttime in timings.items():
|
||||
print(f"config: {config}, time: {ttime}")
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
|
@@ -16,6 +16,9 @@ except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
# TODO: move to separate module
|
||||
import triton
|
||||
|
||||
|
||||
def catch_oor(kernel, pytest_handle=None):
|
||||
try:
|
||||
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
|
||||
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
|
||||
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
|
||||
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
|
||||
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
|
||||
return bw_gbps
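Plugging in the V100 numbers used elsewhere in this diff (877 MHz HBM2 clock; the 4096-bit bus width is assumed here for illustration):

mem_clock_khz, bus_width = 877e3, 4096
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8   # about 898 GB/s, close to the card's ~900 GB/s peak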
|
||||
|
||||
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
|
||||
triton.compiler.init_cuda_utils()
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
if not clock_rate:
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8:
|
||||
assert dtype == torch.float16
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
else:
|
||||
|
@@ -156,7 +156,7 @@ import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
|
@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
|