11 Commits

Author SHA1 Message Date
Jokeren
65896aef9d Debugging 2022-12-13 11:17:40 -08:00
Jokeren
d8d6b9f3f1 Fix v100 fp32 2022-12-12 15:52:16 -08:00
Jokeren
3a1c140385 Add script 2022-12-12 12:10:40 -08:00
Yan Chunwei
0cfe909df8 [Triton-MLIR][BACKEND] some code clean on the backend (#978) 2022-12-12 09:46:16 +00:00
Philippe Tillet
e5cfa0f633 [FRONTEND] Added a few assertions in semantic.dot (#977) 2022-12-12 00:07:14 -08:00
Philippe Tillet
e552219104 [FRONTEND] Add possibility for user to force a GPU threadsync barrier (#976)
The compiler still has pitfalls, even in the master branch.
2022-12-11 23:03:52 -08:00
Philippe Tillet
52accd4c2b [BACKEND] Add isRow attribute for DotOp tensors whose parent is mmav1 (#970)
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
2022-12-11 19:01:57 -08:00
Yan Chunwei
4fb048873a [Triton-MLIR][CI] Fix v100 tests to avoid skipping tests mistakenly (#975) 2022-12-11 04:57:51 +00:00
Keren Zhou
be2f70699c [BACKEND][FRONTEND] Fix problems with test_matmul (#973)
1. Handle induction variable when step is negative
2. Restore async_wait that was accidentally deleted
3. Add missing induction variable in prefetch
4. Add device property functions

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-12-10 20:34:58 -08:00
Yan Chunwei
24fd953f9a [BACKEND] Refine v100 tests and fix mmav1 numwarps>1 hang issue (#971)
This PR:

- Fixes the numWarps > 1 hang issue.
- Adds the existing test cases in test_gemm.py to CI, and adds a common flag
`valid_on_Volta` to determine whether a test case should be activated on Volta
or skipped.
  - Currently, the column-major cases are disabled.
- Adds test_core.py and other tests to the Volta CI.
  - `test_printf.py` currently fails.
2022-12-09 07:41:22 -08:00
goostavz
793012b4c4 [Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1 (#972) 2022-12-09 18:36:05 +08:00
26 changed files with 890 additions and 257 deletions

View File

@@ -89,7 +89,14 @@ jobs:
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
cd python/tests
pytest test_gemm.py::test_gemm_for_mmav1
pytest -k "not test_where_broadcast and not test_dot" test_core.py
pytest test_gemm.py
pytest test_backend.py
pytest test_reduce.py
pytest test_vecadd.py
pytest test_elementwise.py
pytest test_ext_elemwise.py
pytest test_transpose.py
- name: Run CXX unittests
run: |

View File

@@ -43,6 +43,8 @@ bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
bool supportMMA(triton::DotOp op, int version);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>

View File

@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif
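For illustration, here is a minimal C++ sketch (not part of the diff) of how a dot-operand layout carrying the new attribute might be constructed. The helper name, the MmaEncodingAttr::get signature, and the warpsPerCTA values are assumptions; the four-argument DotOperandEncodingAttr builder is the one declared above, and the two-argument builder added above would default isMMAv1Row to true for a version-1 parent.

// Hypothetical helper; assumes the builders above exist with these signatures.
static mlir::Attribute makeMMAv1DotOperand(mlir::MLIRContext *ctx,
                                           unsigned opIdx, bool isRow) {
  // Parent layout of d: MMA v1 (Volta); warpsPerCTA chosen arbitrarily here.
  auto mma = mlir::triton::gpu::MmaEncodingAttr::get(ctx, /*version=*/1,
                                                     /*warpsPerCTA=*/{4, 2});
  // Explicit form; DotOperandEncodingAttr::get(ctx, opIdx, mma) would
  // default isMMAv1Row to true because the parent is MMA v1.
  return mlir::triton::gpu::DotOperandEncodingAttr::get(
      ctx, opIdx, mma, mlir::BoolAttr::get(ctx, isRow));
}

In the IR this prints as, e.g., #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}> (see python/tests/matmul.ttgir below).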

View File

@@ -110,6 +110,19 @@ bool maybeAliasOp(Operation *op) {
isa<tensor::InsertSliceOp>(op);
}
bool supportMMA(triton::DotOp op, int version) {
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
return (aElemTy.isF16() && bElemTy.isF16()) ||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
version >= 2) ||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2);
}
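In other words, as written supportMMA accepts fp16 x fp16 and bf16 x bf16 operands for any MMA version, accepts f32 x f32 only when allowTF32 is set and the version is >= 2, and accepts int8 x int8 only for version >= 2.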
std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
llvm::raw_string_ostream ss(opName);

View File

@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $a.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: A's shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isARow = orderTransed[0] != 0;
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
AParam param(isARow);
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $b.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: B's shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isBRow = orderTransed[0] != 0;
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
BParam param(isBRow);
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
int numM = getNumM(shapeTransed, orderTransed);
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[1];
// NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
unsigned numN = getNumN(shapeTransed, orderTransed);
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[0];
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
unsigned numM = getNumM(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
}
};
unsigned numN = getNumN(shape, order);
unsigned numN = getNumN(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))

View File

@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -981,7 +980,7 @@ struct LoadOpConversion
size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
Value v = undef(vecTy);
for (size_t s = 0; s < size; ++s) {
Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant(
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
SmallVector<std::pair<Value, std::string>> asmArgs;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
Value llWord = undef(wordTy);
// Insert each value element to the composition
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
elem = bitcast(elem, valueElemTy);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =
insert_element(wordTy, llWord, elem,
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
std::string constraint =
@@ -2886,11 +2882,15 @@ private:
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value colWarpOffset = mul(multiDimWarpId[1], _16);
mmaRowIdx[0] =
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
laneIdRem2);
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _8);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
@@ -3336,9 +3336,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
"Unsupported MMA kind found when converting DotOp to LLVM.");
}
if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
A.getType().cast<RankedTensorType>().getElementType().isF32() &&
!op.allowTF32())
// XXX: fp64 has not been tested yet. In theory, it should work.
if (!isMMA)
return convertFMADot(op, adaptor, rewriter);
llvm::report_fatal_error(
@@ -3348,33 +3347,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
// Tell whether a DotOp supports HMMA.
// This is ported from the master branch; the original logic is retained.
static bool isDotHMMA(DotOp op) {
auto a = op.a();
auto b = op.b();
auto c = op.c();
auto d = op.getResult();
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
return false;
auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
auto aElemTy = aTensorTy.getElementType();
auto bElemTy = bTensorTy.getElementType();
assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
"Unexpected MMA layout version found");
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
return (aElemTy.isF16() && bElemTy.isF16()) ||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
mmaLayout.getVersion() >= 2) ||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
mmaLayout.getVersion() >= 2);
return supportMMA(op, mmaLayout.getVersion());
}
// Tell whether a DotOp supports HMMA by the operand type (either $a or $b).
@@ -3428,6 +3410,20 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
@@ -3540,51 +3536,39 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
auto DTensorTy = D.getType().cast<RankedTensorType>();
SmallVector<int> AShape(ATensorTy.getShape().begin(),
ATensorTy.getShape().end());
SmallVector<int> BShape(BTensorTy.getShape().begin(),
BTensorTy.getShape().end());
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot be accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
isAVec4 = true;
isBVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
Value loadedA = adaptor.a();
Value loadedB = adaptor.b();
Value loadedC = adaptor.c();
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
unsigned numM = helper.getNumM(AShape, isARow);
unsigned numN = helper.getNumN(BShape, isBRow);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
// Initialize accumulators with external values. The acc holds the accumulator
// values that are shared between the MMA instructions inside a DotOp; we call
// this ordering of the values the accumulator-internal order.
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
@@ -3697,38 +3681,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
auto bShape = bTensorTy.getShape();
auto cShape = cTensorTy.getShape();
ValueTable has, hbs;
int mShapePerCTA{-1}, nShapePerCTA{-1};
int mSizePerThread{-1}, nSizePerThread{-1};
ArrayRef<unsigned> aOrder, bOrder;
Value llA, llB;
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
DotOpFMAConversionHelper helper(dLayout);
if (auto aDotOpLayout =
aTensorTy.getEncoding()
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
// convert_layout
auto bDotOpLayout =
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
assert(bLayout);
llA = adaptor.a();
llB = adaptor.b();
} else if (auto aLayout =
aTensorTy.getEncoding()
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(bLayout);
Value thread = getThreadId(rewriter, loc);
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
}
Value llA = adaptor.a();
Value llB = adaptor.b();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
@@ -3737,17 +3702,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
int M = aShape[0];
int N = bShape[1];
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
mSizePerThread =
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
nSizePerThread =
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
rewriter, loc);
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
rewriter, loc);
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
SmallVector<Value> ret = cc;
for (unsigned k = 0; k < K; k++) {
@@ -3758,7 +3725,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
hbs[{n + nn, k}], ret[z]);
++z;
}
}
@@ -4288,9 +4254,10 @@ struct ExpOpConversionApprox
// For FP64 input, call __nv_expf for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() == 64)
return {};
const double log2e = 1.4426950408889634;
Value prod =
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");
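For reference, this lowering relies on the identity exp(x) = 2^(x * log2(e)): the input is first scaled by log2(e) ≈ 1.4426950408889634, and the PTX ex2.approx.f32 instruction then computes the base-2 exponential of the product.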
@@ -4779,10 +4746,15 @@ private:
decomposed = true;
});
// async wait is supported in Ampere and later
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
decomposed) {
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
// async wait is supported in Ampere and later
asyncWaitOp.erase();
} else if (decomposed) {
// Wait for all previous async ops
OpBuilder builder(asyncWaitOp);
auto newAsyncWaitOp =
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
asyncWaitOp.erase();
}
});
@@ -4827,6 +4799,8 @@ public:
MembarAnalysis membarPass(&allocation);
membarPass.run();
mod.print(llvm::errs());
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);

View File

@@ -31,6 +31,7 @@
#include <numeric>
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
@@ -40,6 +41,7 @@
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
@@ -90,6 +92,8 @@
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
// Types
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
@@ -102,8 +106,9 @@
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
// Creator for constant
// Constants
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)

View File

@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
// For now, this behaves like generic, but this will evolve when
// we add support for `can_reorder=False`
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
adaptor.getOperands());
return success();
}
};
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>,
TritonCatPattern,
TritonReducePattern,
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
typeConverter, context);
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
}
//

View File

@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
if(!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -36,7 +36,7 @@ namespace {
class DecomposeDotOperand : public mlir::RewritePattern {
public:
DecomposeDotOperand(mlir::MLIRContext *context)
explicit DecomposeDotOperand(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
@@ -84,7 +84,7 @@ public:
// IIUC they are therefore not handled by DRR right now
class SimplifyConversion : public mlir::RewritePattern {
public:
SimplifyConversion(mlir::MLIRContext *context)
explicit SimplifyConversion(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
4, context) {}
@@ -219,8 +219,8 @@ public:
//
// -----------------------------------------------------------------------------
static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
ret = triton::gpu::SliceEncodingAttr::get(
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
return false;
};
}
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
BlockAndValueMapping &mapping) {
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
// are reachable from it without passing through any memory operation.
class RematerializeBackward : public mlir::RewritePattern {
public:
RematerializeBackward(mlir::MLIRContext *context)
explicit RematerializeBackward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
@@ -303,7 +303,7 @@ public:
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
queue.push_back({cvt, targetType.getEncoding()});
queue.emplace_back(cvt, targetType.getEncoding());
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
@@ -341,7 +341,7 @@ public:
continue;
// we add one expensive conversion for the current operand
numCvts += 1;
queue.push_back({opArgI, newEncoding});
queue.emplace_back(opArgI, newEncoding);
}
}
// if rematerialization would add more conversions than it removes
@@ -351,8 +351,8 @@ public:
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
Value v = it->first;
for (auto &item : toConvert) {
Value v = item.first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
@@ -393,7 +393,7 @@ public:
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
SmallVector<Value, 4>
@@ -406,7 +406,7 @@ public:
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
@@ -455,7 +455,7 @@ public:
mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
// continue;
// skip non-tensor types
@@ -517,7 +517,7 @@ public:
class RematerializeForward : public mlir::RewritePattern {
public:
RematerializeForward(mlir::MLIRContext *context)
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
@@ -584,7 +584,7 @@ public:
//
// -----------------------------------------------------------------------------
namespace {
static int computeCapabilityToMMAVersion(int computeCapability) {
int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
}
}
static SmallVector<int64_t, 2>
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
int numWarps) {
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
@@ -608,22 +606,23 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1, shape, numWarps);
mmaVersionToShapePerWarp(1 /*version*/);
bool changed = false;
do {
changed = false;
int pre = ret[0];
if (ret[0] * ret[1] < numWarps) {
ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
changed = true;
changed = pre != ret[0];
}
if (ret[0] * ret[1] < numWarps) {
pre = ret[1];
ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
changed = true;
changed = pre != ret[1];
}
} while (changed);
return ret;
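As a worked example (illustrative, not from the diff): with shape = [256, 128], shapePerWarp = [16, 16], and numWarps = 8, ret grows 1x1 -> 2x1 -> 2x2 -> 4x2, at which point ret[0] * ret[1] == numWarps and the loop exits. Tracking `changed = pre != ret[i]` instead of setting it unconditionally also guarantees termination when a clamp saturates at shape[i] / shapePerWarp[i] before numWarps is covered; with the old code the do/while would never exit in that case.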
@@ -667,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
OptimizeBlockedToShared(mlir::MLIRContext *context)
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
@@ -713,6 +712,53 @@ public:
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
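For example (illustrative, not from the diff): a convert_layout whose source layout has order [0, 1] but whose destination dot-operand layout carries isMMAv1Row = true would hit the unsupported Shared -> DotOperand[MMAv1] case in the lowering shown earlier; this pattern rewrites the destination to an equivalent dot-operand layout with isMMAv1Row = false so that it agrees with the source order, and leaves already-consistent conversions untouched.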
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -726,7 +772,7 @@ public:
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(dotOp, shape, numWarps);
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
@@ -744,18 +790,16 @@ public:
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
int version = computeCapabilityToMMAVersion(computeCapability);
// for FMA, should retain the blocked layout.
if (A.getElementType().isF32() && B.getElementType().isF32() &&
!dotOp.allowTF32())
if (!supportMMA(dotOp, version))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int version = computeCapabilityToMMAVersion(computeCapability);
auto newRetType = RankedTensorType::get(
retShape, oldRetType.getElementType(),
@@ -770,18 +814,36 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (version == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
b, newAcc, dotOp.allowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
@@ -789,6 +851,48 @@ public:
}
};
class FixupLoop : public mlir::RewritePattern {
public:
explicit FixupLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
// Rewrite init argument
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
bool shouldRematerialize = false;
for (size_t i = 0; i < newInitArgs.size(); i++) {
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
shouldRematerialize = true;
break;
}
}
if (!shouldRematerialize)
return failure();
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
for (Operation &op : forOp.getBody()->getOperations()) {
Operation *newOp = rewriter.clone(op, mapping);
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -808,6 +912,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
@@ -818,6 +923,13 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
// llvm::outs() << m << "\n";
mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
signalPassFailure();
}
}
};

View File

@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);

View File

@@ -120,6 +120,7 @@ void init_triton_ir(py::module &&m) {
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
self.getOrLoadDialect<mlir::gpu::GPUDialect>();
});
// .def(py::init([](){
// mlir::MLIRContext context;
@@ -1265,7 +1266,13 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
});
})
// Force GPU barrier
.def("create_barrier",
[](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())

python/tests/matmul.ttgir (new file, 156 lines)
View File

@@ -0,0 +1,156 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_kernel_0d1d2d3d4d5d6d7d8d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : index
%cst = arith.constant dense<32> : tensor<256x32xi32, #blocked0>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
%c8_i32 = arith.constant 8 : i32
%c255_i32 = arith.constant 255 : i32
%c127_i32 = arith.constant 127 : i32
%c32_i32 = arith.constant 32 : i32
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c256_i32 = arith.constant 256 : i32
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%2 = arith.addi %arg3, %c255_i32 : i32
%3 = arith.divsi %2, %c256_i32 : i32
%4 = arith.addi %arg4, %c127_i32 : i32
%5 = arith.divsi %4, %c128_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %0, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.cmpi slt, %9, %c8_i32 : i32
%11 = select %10, %9, %c8_i32 : i32
%12 = arith.remsi %0, %11 : i32
%13 = arith.addi %8, %12 : i32
%14 = arith.remsi %0, %6 : i32
%15 = arith.divsi %14, %11 : i32
%16 = arith.muli %13, %c256_i32 : i32
%17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%19 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%20 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%21 = arith.muli %15, %c128_i32 : i32
%22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%23 = tt.splat %21 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%24 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%25 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%26 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%27 = arith.muli %1, %c32_i32 : i32
%28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%30 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%31 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%32 = arith.addi %19, %17 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%33 = arith.remsi %32, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%34 = tt.splat %arg6 : (i32) -> tensor<256x1xi32, #blocked0>
%35 = arith.addi %30, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x32xi32, #blocked0>
%37 = tt.broadcast %36 : (tensor<1x32xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
%38 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<256x32x!tt.ptr<f16>, #blocked0>
%39 = arith.addi %31, %29 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
%41 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
%42 = arith.addi %23, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%43 = arith.remsi %42, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%44 = tt.expand_dims %43 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
%45 = tt.broadcast %44 : (tensor<1x128xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
%46 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #blocked1>
%47 = arith.index_cast %arg5 : i32 to index
%48 = arith.muli %arg7, %c32_i32 : i32
%49 = tt.splat %48 : (i32) -> tensor<32x128xi32, #blocked1>
%50 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<256x1xi32, #blocked0>
%51 = arith.muli %50, %34 : tensor<256x1xi32, #blocked0>
%52 = tt.broadcast %51 : (tensor<256x1xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
%53 = arith.addi %52, %37 : tensor<256x32xi32, #blocked0>
%54 = tt.addptr %38, %53 : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%55 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<32x1xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
%57 = arith.addi %56, %45 : tensor<32x128xi32, #blocked1>
%58 = tt.addptr %46, %57 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%59 = arith.cmpi slt, %c0, %47 : index
%60 = triton_gpu.alloc_tensor : tensor<2x256x32xf16, #shared0>
%64 = triton_gpu.alloc_tensor : tensor<2x32x128xf16, #shared1>
%61 = tt.splat %59 : (i1) -> tensor<256x32xi1, #blocked0>
%65 = tt.splat %59 : (i1) -> tensor<32x128xi1, #blocked1>
%62 = tt.load %54, %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
%66 = tt.load %58, %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
%63 = tensor.insert_slice %62 into %60[%c0_i32, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
%67 = tensor.insert_slice %66 into %64[%c0_i32, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
%68 = tt.addptr %54, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%69 = tt.addptr %58, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%70 = tensor.extract_slice %63[0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
%71 = tensor.extract_slice %67[0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
%72 = tensor.extract_slice %70[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
gpu.barrier
%73 = triton_gpu.convert_layout %72 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%74 = tensor.extract_slice %71[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%75 = triton_gpu.convert_layout %74 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
%76:14 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %cst_0, %arg11 = %54, %arg12 = %58, %arg13 = %63, %arg14 = %67, %arg15 = %70, %arg16 = %71, %arg17 = %68, %arg18 = %69, %arg19 = %c0, %arg20 = %c1_i32, %arg21 = %c1_i32, %arg22 = %73, %arg23 = %75) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>) {
%104 = arith.addi %arg19, %c32 : index
%105 = arith.cmpi slt, %104, %47 : index
%106 = arith.remsi %arg20, %c2_i32 : i32
%107 = arith.remsi %arg21, %c2_i32 : i32
%108 = arith.index_cast %107 : i32 to index
%200 = arith.index_cast %106 : i32 to index
%109 = tt.splat %105 : (i1) -> tensor<256x32xi1, #blocked0>
%112 = tt.splat %105 : (i1) -> tensor<32x128xi1, #blocked1>
%110 = tt.load %arg17, %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
%113 = tt.load %arg18, %112 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
%96 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
%97 = tensor.extract_slice %arg15[0, 16] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
%98 = triton_gpu.convert_layout %97 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%99 = tensor.extract_slice %arg16[16, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%100 = triton_gpu.convert_layout %99 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
%101 = tt.dot %98, %100, %96 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
%102 = tt.addptr %arg11, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%103 = tt.addptr %arg12, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
gpu.barrier
%111 = tensor.insert_slice %110 into %arg13[%200, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
%114 = tensor.insert_slice %113 into %arg14[%200, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
gpu.barrier
%115 = tt.addptr %arg17, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%116 = tt.addptr %arg18, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%117 = tensor.extract_slice %111[%108, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
%118 = tensor.extract_slice %114[%108, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
%119 = arith.addi %arg20, %c1_i32 : i32
%120 = arith.addi %arg21, %c1_i32 : i32
%121 = tensor.extract_slice %117[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
%122 = triton_gpu.convert_layout %121 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%123 = tensor.extract_slice %118[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%124 = triton_gpu.convert_layout %123 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
scf.yield %101, %102, %103, %111, %114, %117, %118, %115, %116, %104, %119, %120, %122, %124 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
}
gpu.barrier
%77 = triton_gpu.convert_layout %76#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked1>
%78 = arith.addi %20, %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%79 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked1>
%80 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
%81 = tt.broadcast %80 : (tensor<1x128xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
%82 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%83 = "triton_gpu.cmpi"(%78, %25) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%84 = "triton_gpu.cmpi"(%42, %26) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%85 = tt.expand_dims %84 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi1, #blocked1>
%86 = tt.broadcast %85 : (tensor<1x128xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
%87 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi32, #blocked1>
%88 = arith.muli %87, %79 : tensor<256x1xi32, #blocked1>
%89 = tt.broadcast %88 : (tensor<256x1xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
%90 = arith.addi %89, %81 : tensor<256x128xi32, #blocked1>
%91 = tt.addptr %82, %90 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1>
%92 = arith.truncf %77 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%93 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi1, #blocked1>
%94 = tt.broadcast %93 : (tensor<256x1xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
%95 = arith.andi %94, %86 : tensor<256x128xi1, #blocked1>
tt.store %91, %92, %95 : tensor<256x128xf16, #blocked1>
return
}
}

View File

@@ -32,7 +32,7 @@ def matmul_no_scf_kernel(
(shape, num_warps, trans_a, trans_b)
for shape in [
[128, 256, 32],
[256, 128, 16],
# [256, 128, 16],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
@@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
for shape in [
[64, 128, 128],
[128, 128, 128],
[16, 8, 32],
[16, 16, 32],
[32, 16, 64],
[32, 16, 64],
]
@@ -81,6 +81,8 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
for trans_b in [False, True]
])
def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
guard_for_volta(is_int8=True)
SIZE_M, SIZE_N, SIZE_K = SHAPE
if (TRANS_A):
@@ -195,6 +197,7 @@ def get_proper_err(a, b, golden):
[128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
if (TRANS_A):
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
else:
@@ -270,6 +273,8 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, c_mask)
guard_for_volta(is_tf32=allow_tf32)
# Configure the pytorch counterpart
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -294,18 +299,16 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
[16, 16, 16, 1, 16, 16, 16, False, False],
[16, 16, 32, 1, 16, 16, 32, False, False],
[32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
def guard_for_volta(is_int8=False, is_tf32=False):
'''
Tell whether the test case is valid on Volta GPU.
Some features are WIP, so the corresponding support is missing.
'''
capability = torch.cuda.get_device_capability()
is_on_Volta = capability[0] < 8
# TODO[Superjomn]: Remove the constraints below when features are ready
is_feature_supported = not (is_int8 or is_tf32)
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
if is_on_Volta:
if (not is_feature_supported):
pytest.skip("Not valid on Volta")

python/tests/test_matmul.py (new file, 101 lines)
View File

@@ -0,0 +1,101 @@
import itertools
import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
#if DTYPE == "bfloat16" and SPLIT_K != 1:
# pytest.skip("bfloat16 matmuls don't allow split_k for now")
if DTYPE == "bfloat16":
pytest.skip("bfloat16 matmuls doesn't support for now")
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
kernel.configs = configs
# kernel.run = kernel.run.run.run
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)
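The pre_hook above zeroes C whenever SPLIT_K > 1 because each split accumulates a partial K-range into the same output buffer; a small plain-PyTorch illustration of that accumulation (not the Triton kernel itself):

import torch

a = torch.randn(64, 128)
b = torch.randn(128, 64)
c = torch.zeros(64, 64)            # must start at zero, which is what the pre_hook guarantees
for k0 in range(0, 128, 32):       # 4 splits over the K dimension
    c += a[:, k0:k0 + 32] @ b[k0:k0 + 32, :]
assert torch.allclose(c, a @ b, atol=1e-4)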

View File

@@ -0,0 +1,164 @@
import subprocess
import sys
import pytest
import torch
import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, set_gpu_clock
DEVICE_NAME = 'v100'
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(',')
ret = [int(x) for x in ret]
return ret
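A quick usage sketch for the helper above (the printed values are illustrative and depend on the GPU and its current clocks):

# query the current SM and memory clocks of device 0
sm_mhz, mem_mhz = nvsmi(['clocks.current.sm', 'clocks.current.memory'])
print(sm_mhz, mem_mhz)   # e.g. 1350 877 on a V100 locked to the reference clocks below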
#######################
# Matrix Multiplication
#######################
sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
(4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0128},
(16, 4096, 4096): {'float16': 0.0883},
(16, 8192, 8192): {'float16': 0.101},
(64, 1024, 1024): {'float16': 0.073},
(64, 4096, 4096): {'float16': 0.270},
(64, 8192, 8192): {'float16': 0.459},
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
ref_sm_clock = sm_clocks[DEVICE_NAME]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
b = b.t() # only test row-col layout
else:
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
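The utilization check above is just FLOPs over elapsed time; a rough worked example with a made-up timing (a V100-class fp16 tensor-core peak of roughly 125 TFLOP/s is assumed):

M = N = K = 8192
ms = 10.3                                  # hypothetical measured runtime in milliseconds
cur_gpu_perf = 2. * M * N * K / ms * 1e-9  # ~107 TFLOP/s
cur_gpu_util = cur_gpu_perf / 125.0        # ~0.85, close to the 0.849 reference entry above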
#######################
# Element-Wise
#######################
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
elementwise_data = {
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.530,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
},
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
1024 * 256: 0.114,
1024 * 1024: 0.315,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
}
}
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
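The same idea applies to the bandwidth check: the kernel moves 2 loads plus 1 store of N fp16 elements. A rough sketch with an illustrative timing (a V100 peak of about 900 GB/s is assumed):

N = 1024 * 65536
ms = 0.48                                  # hypothetical measured runtime in milliseconds
cur_gpu_perf = 3. * N * 2 / ms * 1e-6      # ~839 GB/s moved (fp16 element_size == 2)
cur_gpu_util = cur_gpu_perf / 900.0        # ~0.93, in line with the 0.939 reference entry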

View File

@@ -68,7 +68,7 @@ def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
@num_elements: number of elements
'''
pid = tl.program_id(axis=0)
for i in range(math.ceil(block_size / iter_size)):
for i in range(tl.cdiv(block_size, iter_size)):
# TODO: a bug here; if the offset is computed outside the for loop, a GPU misaligned-address error occurs.
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
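For reference, tl.cdiv(block_size, iter_size) is ceiling division on integers; a plain-Python equivalent showing the value it produces:

def cdiv(x, y):
    return (x + y - 1) // y

assert cdiv(100, 32) == 4 == -(-100 // 32)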

View File

@@ -329,10 +329,6 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
@@ -591,8 +587,10 @@ class CodeGenerator(ast.NodeVisitor):
ast.NodeVisitor.generic_visit(self, stmt)
return
# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if isinstance(step, triton.language.constexpr) and step.value < 0:
step = triton.language.constexpr(-step.value)
negative_step = True
lb, ub = ub, lb
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle
@@ -640,6 +638,9 @@ class CodeGenerator(ast.NodeVisitor):
# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
if negative_step:
ub_si = self.builder.create_index_to_si(ub)
iv = self.builder.create_sub(ub_si, iv)
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
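The rewrite above turns a negative-step loop into a positive-step scf.for plus an index remap; the equivalence it relies on can be checked in plain Python (the bounds here are hypothetical):

lb, ub, step = 9, 0, -2                              # a decreasing range(lb, ub, step)
forward = [lb - iv for iv in range(ub, lb, -step)]   # swapped bounds, negated step, iv remapped
assert forward == list(range(lb, ub, step))          # [9, 7, 5, 3, 1]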
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.enable_debug()
# Convert blocked layout to mma layout for dot ops so that pipeline
# can get shared memory swizzled correctly.
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_triton_gpu_combine_pass(compute_capability)
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
@@ -1358,12 +1359,12 @@ def make_hash(fn, **kwargs):
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
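To make the comment concrete, here is what the MLIR pattern is intended to capture on a typical prototype line (the function name and argument types below are made up):

import re
line = "func public @add_kernel(%arg0: !tt.ptr<f32>, %arg1: i32) {"
m = re.search(mlir_prototype_pattern, line, re.MULTILINE)
assert m.group(1) == "@add_kernel"
assert m.group(2) == "(%arg0: !tt.ptr<f32>, %arg1: i32)"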
@@ -1384,6 +1385,8 @@ arg_type_pattern = {
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
# has to be a path to a file
@@ -1391,24 +1394,22 @@ def compile(fn, **kwargs):
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
capability = torch.cuda.get_device_capability()
capability = capability[0]*10 + capability[1]
# build compilation stages
stages = {
"ast" : (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
@@ -1430,9 +1431,7 @@ def compile(fn, **kwargs):
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
print(name, signature)
types = re.findall(arg_type_pattern[ir], signature)
print(types)
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir)
@@ -1467,8 +1466,8 @@ def compile(fn, **kwargs):
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
@@ -1504,8 +1503,7 @@ class CompiledKernel:
self.asm = asm
device = torch.cuda.current_device()
global cuda_utils
if cuda_utils is None:
cuda_utils = CudaUtils()
init_cuda_utils()
mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
@@ -1562,6 +1560,34 @@ class CudaUtils(object):
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
int device_id;
if(!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
// Get device handle
CUdevice device;
cuDeviceGet(&device, device_id);
// create a struct to hold device properties
int max_shared_mem;
int multiprocessor_count;
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
"multiprocessor_count", multiprocessor_count,
"sm_clock_rate", sm_clock_rate,
"mem_clock_rate", mem_clock_rate,
"mem_bus_width", mem_bus_width);
}
static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name;
const char* data;
@@ -1601,6 +1627,7 @@ class CudaUtils(object):
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
@@ -1640,6 +1667,13 @@ class CudaUtils(object):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
def init_cuda_utils():
global cuda_utils
if cuda_utils is None:
cuda_utils = CudaUtils()
cuda_utils = None
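A short usage sketch for the property helper added above (device index 0 is assumed; callers outside this module go through triton.compiler.init_cuda_utils(), as the perf-model changes further down do):

init_cuda_utils()
props = cuda_utils.get_device_properties(0)
# keys: max_shared_mem, multiprocessor_count, sm_clock_rate, mem_clock_rate, mem_bus_width
print(props["multiprocessor_count"], props["max_shared_mem"])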

View File

@@ -9,6 +9,7 @@ from triton._C.libtriton.triton import ir
T = TypeVar('T')
def _to_tensor(x, builder):
if isinstance(x, bool):
return tensor(builder.get_int1(x), int1)
@@ -348,6 +349,9 @@ class constexpr:
def __mul__(self, other):
return constexpr(self.value * other.value)
def __mod__(self, other):
return constexpr(self.value % other.value)
def __rmul__(self, other):
return constexpr(other.value * self.value)
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
"""
return semantic.broadcast_impl_shape(input, shape, _builder)
@builtin
def trans(input, _builder=None):
return semantic.trans(input, _builder)
@builtin
def cat(input, other, can_reorder=False, _builder=None):
"""
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
shape = [x.value for x in shape]
return semantic.view(input, shape, _builder)
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view

View File

@@ -481,7 +481,8 @@ def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
assert len(input.shape) == len(dst_shape)
# assert len(input.shape) == len(dst_shape)
numel = 1
for s in dst_shape:
numel *= s
@@ -982,6 +983,11 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[1].value == rhs.shape[0].value
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
if lhs.type.scalar.is_int():
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
@@ -1138,7 +1144,7 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(''), tl.void)
return tl.tensor(builder.create_barrier(), tl.void)
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:

View File

@@ -26,32 +26,30 @@ def get_configs_io_bound():
return configs
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
#configs=[
# # basic configs for compute-bound matmuls
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# # good for int8
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
#] + get_configs_io_bound(),
configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=2, num_warps=8)],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
@@ -59,6 +57,9 @@ def get_configs_io_bound():
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
@@ -113,7 +114,7 @@ def _kernel(A, B, C, M, N, K,
class _matmul(torch.autograd.Function):
kernel = _kernel
kernel = None
_locks = dict()
@@ -134,12 +135,17 @@ class _matmul(torch.autograd.Function):
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
#grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
if _matmul.kernel is None:
_matmul.kernel = triton.compile("/root/code/triton-mlir/python/tests/matmul.ttgir", num_stages=2, num_warps=8)
#_matmul.kernel = _kernel
_matmul.kernel[(8192//256 * 8192//128, 1, 1,)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, K,
a.stride(0), b.stride(0), c.stride(0))
#_matmul.kernel[grid](a, b, c,
# M, N, K,
# a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
# GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
@staticmethod

View File

@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcor
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops
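The scaling above de-rates peak throughput by how many of the GPU's tensor-core subcores the launch can keep busy; a rough illustration with assumed numbers (80 SMs, i.e. 320 subcores, and a 125 TFLOP/s fp16 peak):

num_subcores = 80 * 4                      # 4 subcores per SM on recent GPUs
total_warps = 1 * 4                        # e.g. a single CTA of 4 warps
tflops = min(num_subcores, total_warps) / num_subcores * 125.0
# -> 1.5625 TFLOP/s: a single small CTA can only use a tiny fraction of the machine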
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8 and dtype == torch.float32:
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
@@ -59,7 +61,7 @@ def estimate_matmul_time(
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
@@ -99,7 +101,7 @@ def estimate_matmul_time(
def early_config_prune(configs, named_args):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
capability = torch.cuda.get_device_capability()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
# TODO: move to `cuda_utils` submodule
triton.compiler.init_cuda_utils()
max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
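The shared-memory filter above is simple arithmetic; for example, an fp16 config with 128x128x32 blocks and 4 stages needs:

BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize = 128, 128, 32, 4, 2   # fp16
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
# = 65536 bytes (64 KiB), to be compared against the device's reported max_shared_mem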
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
if capability[0] >= 8:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas / min(4, num_warps) * 8

View File

@@ -74,6 +74,8 @@ class Autotuner(KernelInterface):
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
for config, ttime in timings.items():
print(f"config: {config}, time: {ttime}")
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings

View File

@@ -16,6 +16,9 @@ except ImportError:
_cutlass = None
has_cutlass = False
# TODO: move to separate module
import triton
def catch_oor(kernel, pytest_handle=None):
try:
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
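As a sanity check on the formula, plugging in V100-like numbers (877 MHz memory clock and a 4096-bit HBM2 bus, both assumed here):

mem_clock_khz = 877 * 1000
bus_width = 4096
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8   # ~898 GB/s, matching the ~900 GB/s spec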
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
if not clock_rate:
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8:
assert dtype == torch.float16
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:

View File

@@ -156,7 +156,7 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
],
key=['M', 'N', 'K'],
)

View File

@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {