Compare commits
13 Commits: keren/inse...phil/mma-v

Commits: 58d2867fe6, fa6dbbff60, 13644e7ac4, 0d27912554, 3ed36dcb4d, 83f3b9165b, 71c35bcf9c, c7cf9c6a32, f0885e9caf, 18e683d9bb, 4eab9dcedf, b2b793dfb5, 981aee7f1e

.github/workflows/integration-tests.yml (vendored, 2 changes)
@@ -89,7 +89,7 @@ jobs:
        if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
        run: |
          cd python/tests
          pytest test_gemm.py::test_gemm_no_scf_for_mmav1
          pytest test_gemm.py::test_gemm_for_mmav1

      - name: Run CXX unittests
        run: |
|
||||
|
@@ -103,15 +103,12 @@ def TT_AddPtrOp : TT_Op<"addptr",
                   SameOperandsAndResultShape,
                   SameOperandsAndResultEncoding,
                   TypesMatchWith<"result type matches ptr type",
                                  "result", "ptr", "$_self">,
                   TypesMatchWith<"result shape matches offset shape",
                                  "result", "offset",
                                  "getI32SameShape($_self)">]> {
    let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
                                  "result", "ptr", "$_self">]> {
    let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);

    let results = (outs TT_PtrLike:$result);

    let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
    let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}
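Read together, the trait and argument changes above relax tt.addptr's offset from i32 tensors to any integer tensor of matching shape, which is why the assembly format now also prints type($offset). A rough Python restatement of the resulting typing rule; the names here are illustrative, not from the source:

from dataclasses import dataclass

@dataclass
class TensorType:
    shape: tuple
    element: str      # e.g. "ptr<f16>", "i32", "i64"

def addptr_result_type(ptr: TensorType, offset: TensorType) -> TensorType:
    # "result shape matches offset shape" / SameOperandsAndResultShape
    assert offset.shape == ptr.shape
    # the offset may now be any integer width (TT_IntLike), not just i32
    assert offset.element in ("i1", "i8", "i16", "i32", "i64")
    # "result type matches ptr type": the result keeps the pointer type
    return ptr

# e.g. a 128x128 pointer tensor with an i64 offset tensor is now accepted
addptr_result_type(TensorType((128, 128), "ptr<f16>"), TensorType((128, 128), "i64"))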
|
||||
|
||||
|
||||
|
@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.

For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];

  let parameters = (
    ins
    "unsigned":$opIdx,
    "Attribute":$parent
    "Attribute":$parent,
    "Attribute":$isMMAv1Row
  );

  let builders = [
    AttrBuilder<(ins "unsigned":$opIdx,
                     "Attribute":$parent), [{
      Attribute isMMAv1Row;
      if(parent.isa<MmaEncodingAttr>() &&
         parent.cast<MmaEncodingAttr>().getVersion() == 1){
        isMMAv1Row = BoolAttr::get(context, true);
      }
      return $_get(context, opIdx, parent, isMMAv1Row);
    }]>
  ];

  let extraClassDeclaration = extraBaseClassDeclaration;
}

#endif
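The builder above only has to fill in the new parameter when the parent is an MMA v1 layout; for every other parent the attribute stays unset. A small Python restatement of that defaulting rule, for illustration only:

def default_is_mma_v1_row(parent_is_mma: bool, parent_version: int):
    # mirrors the AttrBuilder body: only MMA v1 parents get a row/col flag,
    # and it defaults to "row" (true); other parents leave it unset (None)
    if parent_is_mma and parent_version == 1:
        return True
    return None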
|
||||
|
@@ -32,6 +32,12 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
  let arguments = (ins I32Attr:$num);

  let assemblyFormat = "attr-dict";

  let extraClassDeclaration = [{
    static bool isSupported(int computeCapability) {
      return computeCapability >= 80;
    }
  }];
}

// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.

@@ -152,7 +158,13 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
  //}];

  let extraClassDeclaration = [{
    static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
    static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
      DenseSet<unsigned> validLoadBytes;
      if (computeCapability >= 80) {
        validLoadBytes = {4, 8, 16};
      }
      return validLoadBytes;
    }
  }];

  // The custom parser could be replaced with oilist in LLVM-16
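Both helpers above gate features on the compute capability: async waits exist only from Ampere (sm_80) on, and only then are 4/8/16-byte asynchronous copies eligible. A Python sketch of the same gating, for reference:

def async_wait_is_supported(compute_capability: int) -> bool:
    return compute_capability >= 80          # Ampere and later

def eligible_load_byte_widths(compute_capability: int) -> set:
    # empty set on pre-sm_80 devices, so InsertSliceAsyncOp gets decomposed
    return {4, 8, 16} if compute_capability >= 80 else set()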
|
||||
|
@@ -43,12 +43,51 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
struct DotOpMmaV1ConversionHelper {
  MmaEncodingAttr mmaLayout;
  ArrayRef<unsigned> wpt;
  static constexpr std::array<int, 3> fpw{{2, 2, 1}};

  using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

  explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}

  // Help to share some variables across multiple functions for A.
  struct AParam {
    SmallVector<int> rep;
    SmallVector<int> spw;

    // TODO[Superjomn]: Support the case when isAVec4=false later
    // Currently, we only support ld.v2, for the mma layout varies with
    // different ld vector width.
    // bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
    const bool isAVec4{true};

    explicit AParam(bool isARow) {
      int packSize0 = (isARow || isAVec4) ? 1 : 2;
      int repM = 2 * packSize0;
      int repK = 1;
      int spwM = fpw[0] * 4 * repM;
      rep.assign({repM, 0, repK});
      spw.assign({spwM, 0, 1});
    }
  };

  // Help to share some variables across multiple functions for B.
  struct BParam {
    SmallVector<int> rep;
    SmallVector<int> spw;
    // TODO[Superjomn]: Support the case when isBVec4=false later
    // Currently, we only support ld.v2, for the mma layout varies with
    // different ld vector width.
    // bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
    const bool isBVec4{true};

    explicit BParam(bool isBRow) {
      int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
      rep.assign({0, 2 * packSize1, 1});
      spw.assign({0, fpw[1] * 4 * rep[1], 1});
    }
  };

  int getRepM(int M) const {
    return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
  }
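A minimal Python sketch of the AParam/BParam arithmetic above, under the same restriction the TODOs state (only ld.v2 is supported, so the vec4 flags are forced true); the function names are chosen for illustration:

FPW = [2, 2, 1]  # fragments per warp, as in fpw above

def a_param(is_a_row: bool):
    is_a_vec4 = True                               # forced, see TODO[Superjomn]
    pack_size0 = 1 if (is_a_row or is_a_vec4) else 2
    rep_m, rep_k = 2 * pack_size0, 1
    spw_m = FPW[0] * 4 * rep_m
    return [rep_m, 0, rep_k], [spw_m, 0, 1]        # rep, spw (N slot padded with 0)

def b_param(is_b_row: bool):
    is_b_vec4 = True                               # forced
    pack_size1 = 2 if (is_b_row and not is_b_vec4) else 1
    rep = [0, 2 * pack_size1, 1]                   # M slot padded with 0
    spw = [0, FPW[1] * 4 * rep[1], 1]
    return rep, spw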
|
||||
@@ -65,24 +104,34 @@ struct DotOpMmaV1ConversionHelper {
    return struct_ty(SmallVector<Type>{8, fp32Ty});
  }

  // number of fp16x2 elements for $a.
  int numElemsPerThreadA(RankedTensorType tensorTy) const {
    auto shape = tensorTy.getShape();
    auto order = getOrder();
  // Get the number of fp16x2 elements for $a.
  // \param shapeTransed: the shape or reordered shape if transpose needed.
  // \param orderTransed: the order or reordered order if transpose needed.
  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
                   ArrayRef<unsigned> orderTransed) const {
    bool isARow = orderTransed[0] != 0;
    AParam param(isARow);

    bool isARow = order[0] != 0;
    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
    int packSize0 = (isARow || isAVec4) ? 1 : 2;
    unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
    return numM;
  }

    SmallVector<int> fpw({2, 2, 1});
    int repM = 2 * packSize0;
    int repK = 1;
    int spwM = fpw[0] * 4 * repM;
    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
  // Get the number of fp16x2 elements for $b.
  // \param shapeTransed: the shape or reordered shape if transpose needed.
  // \param orderTransed: the order or reordered order if transpose needed.
  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
                   ArrayRef<unsigned> orderTransed) const {
    bool isBRow = orderTransed[0] != 0;
    BParam param(isBRow);

    int NK = shape[1];
    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
    unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
    return numN;
  }

  int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
                         ArrayRef<unsigned> orderTransed) const {
    int numM = getNumM(shapeTransed, orderTransed);
    int NK = shapeTransed[1];

    // NOTE: We couldn't get the vec from the shared layout.
    // int vecA = sharedLayout.getVec();
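To make the index arithmetic above concrete, here is a worked example in Python. The tile shape (M = N = K = 64) and warpsPerCTA wpt = [4, 1] are assumptions picked only for illustration; the rep/spw values are the ones AParam(true) and BParam(true) produce:

wpt = [4, 1]                 # assumed warpsPerCTA
M, N, K = 64, 64, 64         # assumed tile shape
rep_a, spw_a = [2, 0, 1], [16, 0, 1]   # AParam(isARow=True)
rep_b, spw_b = [0, 2, 1], [0, 16, 1]   # BParam(isBRow=True)

num_m = rep_a[0] * M // (spw_a[0] * wpt[0])     # 2*64 // (16*4) = 2
num_n = rep_b[1] * N // (spw_b[1] * wpt[1])     # 2*64 // (16*1) = 8

elems_per_ld = 2             # ld.v2 only, as noted above
elems_a = (num_m // 2) * (K // 4) * elems_per_ld   # 1 * 16 * 2 = 32 fp16x2 values
elems_b = (num_n // 2) * (K // 4) * elems_per_ld   # 4 * 16 * 2 = 128 fp16x2 values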
|
||||
@@ -92,34 +141,27 @@ struct DotOpMmaV1ConversionHelper {
|
||||
return (numM / 2) * (NK / 4) * elemsPerLd;
|
||||
}
|
||||
|
||||
// number of fp16x2 elements for $b.
|
||||
int numElemsPerThreadB(RankedTensorType tensorTy) const {
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = getOrder();
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed);
|
||||
int NK = shapeTransed[0];
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecB = sharedLayout.getVec();
|
||||
// TODO[Superjomn]: Consider the case when vecA > 4
|
||||
bool vecGt4 = false;
|
||||
int elemsPerLd = vecGt4 ? 4 : 2;
|
||||
int NK = shape[0];
|
||||
|
||||
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
|
||||
return (numN / 2) * (NK / 4) * elemsPerLd;
|
||||
}
|
||||
|
||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||
Value loadA(Value A, bool transA, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||
Value loadB(Value B, bool transB, const SharedMemoryObject &smemObj,
|
||||
Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
|
||||
|
||||
@@ -1311,8 +1353,22 @@ struct DotOpFMAConversionHelper {
|
||||
};
|
||||
|
||||
Value DotOpMmaV1ConversionHelper::loadA(
|
||||
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// [1, 0] (isRow = True)
|
||||
// x x x x || x x x x
|
||||
// x x x x || x x x x
|
||||
// stride = [8, 1]
|
||||
// strideA0 = strideAk = 1
|
||||
// strideA1 = strideAm = 8
|
||||
|
||||
// [0, 1] (isRow = False)
|
||||
// x x x x || x x x x
|
||||
// x x x x || x x x x
|
||||
// stride = [1, 2]
|
||||
// strideA0 = strideAm = 1
|
||||
// strideA1 = strideAk = 2
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
@@ -1322,28 +1378,15 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
// Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
Value smemBase = smemObj.base;
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
|
||||
// TODO[Superjomn]: Support the case when isAVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with different
|
||||
// ld vector width.
|
||||
isAVec4 = true;
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
AParam param(isARow);
|
||||
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
int repM = 2 * packSize0;
|
||||
int repK = 1;
|
||||
int spwM = fpw[0] * 4 * repM;
|
||||
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
|
||||
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
|
||||
auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
|
||||
thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
|
||||
|
||||
auto [offsetAM, offsetAK, _0, _1] =
|
||||
computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
|
||||
// TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
|
||||
bool transA = false;
|
||||
if (transA) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(offsetAM, offsetAK);
|
||||
@@ -1358,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
Value strideA0 = isARow ? strideAK : strideAM;
|
||||
Value strideA1 = isARow ? strideAM : strideAK;
|
||||
|
||||
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
|
||||
int strideRepM = wpt[0] * fpw[0] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
@@ -1372,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
Value offA0 = isARow ? offsetAK : offsetAM;
|
||||
Value offA1 = isARow ? offsetAM : offsetAK;
|
||||
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
|
||||
offA0 = add(offA0, cSwizzleOffset);
|
||||
// offA0 = add(offA0, smemObj.offsets[order[0]]);
|
||||
// offA1 = add(offA1, smemObj.offsets[order[1]]);
|
||||
|
||||
SmallVector<Value> offA(numPtrA);
|
||||
for (int i = 0; i < numPtrA; i++) {
|
||||
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
|
||||
@@ -1391,10 +1437,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
for (int i = 0; i < numPtrA; i++)
|
||||
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
|
||||
|
||||
unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
|
||||
|
||||
Type f16PtrTy = ptr_ty(f16_ty);
|
||||
|
||||
|
||||
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
|
||||
vals[{m, k}] = {val0, val1};
|
||||
};
|
||||
@@ -1424,6 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
unsigned numM = getNumM(shape, order);
|
||||
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
|
||||
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
loadA(m, k);
|
||||
@@ -1441,8 +1490,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
|
||||
Value DotOpMmaV1ConversionHelper::loadB(
|
||||
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value tensor, bool transB, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const {
|
||||
// smem
|
||||
auto strides = smemObj.strides;
|
||||
|
||||
@@ -1455,20 +1504,11 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with different
|
||||
// ld vector width.
|
||||
isBVec4 = true;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
||||
int vecB = sharedLayout.getVec();
|
||||
BParam param(isBRow);
|
||||
|
||||
int vecB = sharedLayout.getVec();
|
||||
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
||||
Value strideBK = isBRow ? strides[0] : i32_val(1);
|
||||
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||
@@ -1476,11 +1516,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
int strideRepN = wpt[1] * fpw[1] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
// TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
|
||||
bool transB = false;
|
||||
|
||||
auto [_0, _1, offsetBN, offsetBK] =
|
||||
computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
|
||||
auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
|
||||
thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
|
||||
if (transB) {
|
||||
std::swap(order[0], order[1]);
|
||||
std::swap(shape[0], shape[1]);
|
||||
@@ -1547,7 +1584,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
|
||||
unsigned numN = getNumN(shape, order);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
|
@@ -973,8 +973,10 @@ struct LoadOpConversion

      if (other) {
        for (size_t ii = 0; ii < nWords; ++ii) {
          // PTX doesn't support mov.u8, so we need to use mov.u16
          auto movWidth = width < 16 ? 16 : width;
          PTXInstr &mov =
              ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
              ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
@@ -1707,6 +1709,43 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
|
||||
/// ====================== reduce codegen end ==========================
|
||||
|
||||
/// ====================== cat codegen begin ==========================
|
||||
|
||||
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
using OpAdaptor = typename CatOp::Adaptor;
|
||||
|
||||
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
// unpack input values
|
||||
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
|
||||
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
|
||||
// concatenate (and potentially reorder) values
|
||||
SmallVector<Value> retVals;
|
||||
for (Value v : lhsVals)
|
||||
retVals.push_back(v);
|
||||
for (Value v : rhsVals)
|
||||
retVals.push_back(v);
|
||||
// pack and replace
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ====================== cat codegen end ==========================
|
||||
|
||||
template <typename SourceOp>
|
||||
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
@@ -2643,63 +2682,24 @@ public:
|
||||
dstLayout.isa<SliceEncodingAttr>())) {
|
||||
return lowerDistributedToDistributed(op, adaptor, rewriter);
|
||||
}
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
if (srcLayout.isa<MmaEncodingAttr>() &&
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
|
||||
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
|
||||
if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dstDotLayout.getOpIdx() == 0 &&
|
||||
dstDotLayout.getParent() == srcMmaLayout) {
|
||||
// get source values
|
||||
Location loc = op->getLoc();
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
// so they can be consumed by tensor core operations
|
||||
unsigned vecSize =
|
||||
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
SmallVector<Type> types(elems / vecSize, vecTy);
|
||||
SmallVector<Value> vecVals;
|
||||
for (unsigned i = 0; i < elems; i += vecSize) {
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (unsigned j = 0; j < vecSize; j++)
|
||||
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
|
||||
vecVals.push_back(packed);
|
||||
}
|
||||
|
||||
// This needs to be ordered the same way that
|
||||
// ldmatrix.x4 would order it
|
||||
// TODO: this needs to be refactor so we don't
|
||||
// implicitly depends on how emitOffsetsForMMAV2
|
||||
// is implemented
|
||||
SmallVector<Value> reorderedVals;
|
||||
for (unsigned i = 0; i < vecVals.size(); i += 4) {
|
||||
reorderedVals.push_back(vecVals[i]);
|
||||
reorderedVals.push_back(vecVals[i + 2]);
|
||||
reorderedVals.push_back(vecVals[i + 1]);
|
||||
reorderedVals.push_back(vecVals[i + 3]);
|
||||
}
|
||||
|
||||
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
|
||||
|
||||
Type structTy =
|
||||
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
Value view =
|
||||
getStructFromElements(loc, reorderedVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
return lowerMmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
// TODO: to be implemented
|
||||
llvm_unreachable("unsupported layout conversion");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                               DotOperandEncodingAttr &dotOperandLayout) {
  // dot_op<opIdx=0, parent=#mma> = #mma
  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
  return mmaLayout.getWarpsPerCTA()[1] == 1 &&
         dotOperandLayout.getOpIdx() == 0 &&
         dotOperandLayout.getParent() == mmaLayout;
}
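When this shortcut applies, the conversion (shown earlier in this hunk and again in lowerMmaToDotOperand below) packs the MMA v2 accumulator values into 32-bit vectors and then reorders each group of four the way ldmatrix.x4 would. A Python sketch of just the reordering step:

def reorder_for_ldmatrix_x4(vec_vals):
    # same permutation as the loop over vecVals: i, i+2, i+1, i+3
    out = []
    for i in range(0, len(vec_vals), 4):
        out += [vec_vals[i], vec_vals[i + 2], vec_vals[i + 1], vec_vals[i + 3]]
    return out

assert reorder_for_ldmatrix_x4(list(range(8))) == [0, 2, 1, 3, 4, 6, 5, 7]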
|
||||
|
||||
static void storeBlockedToShared(Value src, Value llSrc,
|
||||
ArrayRef<Value> srcStrides,
|
||||
ArrayRef<Value> srcIndices, Value dst,
|
||||
@@ -2964,6 +2964,11 @@ private:
|
||||
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// mma -> dot_operand
|
||||
LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// shared -> dot_operand if the result layout is mma
|
||||
Value lowerSharedToDotOperandMMA(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
@@ -3170,6 +3175,58 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertLayoutOpConversion::lowerMmaToDotOperand(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcTy = op.src().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.result().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
|
||||
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
|
||||
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
|
||||
// get source values
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
// so they can be consumed by tensor core operations
|
||||
unsigned vecSize =
|
||||
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
SmallVector<Type> types(elems / vecSize, vecTy);
|
||||
SmallVector<Value> vecVals;
|
||||
for (unsigned i = 0; i < elems; i += vecSize) {
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (unsigned j = 0; j < vecSize; j++)
|
||||
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
|
||||
vecVals.push_back(packed);
|
||||
}
|
||||
|
||||
// This needs to be ordered the same way that
|
||||
// ldmatrix.x4 would order it
|
||||
// TODO: this needs to be refactor so we don't
|
||||
// implicitly depends on how emitOffsetsForMMAV2
|
||||
// is implemented
|
||||
SmallVector<Value> reorderedVals;
|
||||
for (unsigned i = 0; i < vecVals.size(); i += 4) {
|
||||
reorderedVals.push_back(vecVals[i]);
|
||||
reorderedVals.push_back(vecVals[i + 2]);
|
||||
reorderedVals.push_back(vecVals[i + 1]);
|
||||
reorderedVals.push_back(vecVals[i + 3]);
|
||||
}
|
||||
|
||||
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
struct InsertSliceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -3370,15 +3427,33 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
}
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
if (dotOperandLayout.getOpIdx() == 0) {
|
||||
// operand $a
|
||||
res =
|
||||
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||
// operand $b
|
||||
res =
|
||||
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
bool isMMAv1Row =
|
||||
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto srcSharedLayout = src.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<SharedEncodingAttr>();
|
||||
|
||||
// Can only convert [1, 0] to row or [0, 1] to col for now
|
||||
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
|
||||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
|
||||
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
|
||||
return Value();
|
||||
}
|
||||
|
||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
|
||||
// TODO[Superjomn]: transA is not available here.
|
||||
bool transA = false;
|
||||
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
|
||||
// TODO[Superjomn]: transB is not available here.
|
||||
bool transB = false;
|
||||
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported mma layout found");
|
||||
@@ -3481,6 +3556,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
auto ALayout = A.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
auto BLayout = B.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
@@ -3492,14 +3575,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
// TODO[Superjomn]: order cannot accessed in DotOp.
|
||||
SmallVector<unsigned> AOrder({1, 0});
|
||||
SmallVector<unsigned> BOrder({1, 0});
|
||||
|
||||
bool isARow = AOrder[0] != 0;
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
isAVec4 = true;
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
@@ -3512,20 +3595,24 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
|
||||
|
||||
// initialize accumulators
|
||||
// Initialize accumulators with external values, the acc holds the accumulator
|
||||
// value that is shared between the MMA instructions inside a DotOp, we can
|
||||
// call the order of the values the accumulator-internal order.
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
// NOTE The current order of resVals is different from acc, we call it the
|
||||
// accumulator-external order. and
|
||||
SmallVector<Value> resVals(resSize);
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
auto getIdx = [&](int m, int n) {
|
||||
std::vector<size_t> idx{{
|
||||
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
|
||||
(m * 2 + 0) + (n * 4 + 1) * numM,
|
||||
@@ -3536,8 +3623,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
|
||||
(m * 2 + 1) + (n * 4 + 3) * numM,
|
||||
}};
|
||||
return idx;
|
||||
};
|
||||
|
||||
{ // convert the acc's value from accumuator-external order to
|
||||
// accumulator-internal order.
|
||||
SmallVector<Value> accInit(acc.size());
|
||||
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
auto idx = getIdx(m, n);
|
||||
for (unsigned i = 0; i < 8; ++i)
|
||||
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
|
||||
}
|
||||
|
||||
acc = accInit;
|
||||
}
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
|
||||
PTXBuilder builder;
|
||||
auto idx = getIdx(m, n);
|
||||
|
||||
auto *resOprs = builder.newListOperand(8, "=f");
|
||||
auto *AOprs = builder.newListOperand({
|
||||
@@ -3569,8 +3677,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
for (unsigned i = 0; i < 8; i++) {
|
||||
Value elem = extract_val(f32_ty, res, getIntAttr(i));
|
||||
acc[idx[i]] = elem;
|
||||
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
|
||||
// verified before MMA
|
||||
resVals[(m * numN / 2 + n) * 8 + i] = elem;
|
||||
}
|
||||
};
|
||||
@@ -3776,7 +3882,8 @@ public:
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
auto shape = type.getShape();
|
||||
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
|
||||
|
||||
if (layout &&
|
||||
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
||||
layout.isa<MmaEncodingAttr>())) {
|
||||
@@ -3839,13 +3946,22 @@ public:
|
||||
if (mmaLayout.getVersion() == 1) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
// TODO[Superjomn]: Both transA and transB are not available here.
|
||||
bool trans = false;
|
||||
// TODO[Superjomn]: The order of A and B are not available here.
|
||||
SmallVector<unsigned> order({1, 0});
|
||||
if (trans) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(order[0], order[1]);
|
||||
}
|
||||
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
int elems = helper.numElemsPerThreadA(type);
|
||||
int elems = helper.numElemsPerThreadA(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
int elems = helper.numElemsPerThreadB(type);
|
||||
int elems = helper.numElemsPerThreadB(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
@@ -4344,6 +4460,7 @@ struct AtomicRMWOpConversion
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
|
||||
Value rmwPtr = ptrElements[i];
|
||||
Value rmwMask = maskElements[i];
|
||||
rmwMask = and_(rmwMask, mask);
|
||||
@@ -4401,9 +4518,10 @@ struct AtomicRMWOpConversion
|
||||
atom.o(rmwOp).o(sTy);
|
||||
if (valueTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i * vec + ii] =
|
||||
resultVals[i + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
||||
}
|
||||
} else {
|
||||
@@ -4537,6 +4655,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<TransOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
@@ -4544,6 +4663,34 @@ class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
|
||||
private:
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) {
|
||||
// replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcMma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcMma && dstDotOp &&
|
||||
!ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
|
||||
getOrder(srcMma), numWarps));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeBlockedToDotOperand(ModuleOp mod) {
|
||||
// replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
|
||||
// because the codegen doesn't handle `blocked -> dot_op` directly
|
||||
@@ -4571,6 +4718,47 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
void rewriteConvertToDotOperand(ModuleOp mod) {
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
|
||||
OpBuilder builder(cvt);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if(auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if(auto srcSharedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return;
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return;
|
||||
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return;
|
||||
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if((order[0] == 1 && isMMAv1Row) ||
|
||||
(order[0] == 0 && !isMMAv1Row))
|
||||
return;
|
||||
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
|
||||
newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(),
|
||||
dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvt.getLoc(), newDstType, cvt.getOperand());
|
||||
cvt.replaceAllUsesWith(newCvt.getResult());
|
||||
cvt.erase();
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
@@ -4620,8 +4808,7 @@ private:
|
||||
// capability does not support async copy, then we do decompose
|
||||
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||
computeCapability)
|
||||
.contains(byteWidth) &&
|
||||
computeCapability >= 80)
|
||||
.contains(byteWidth))
|
||||
return;
|
||||
|
||||
// load
|
||||
@@ -4655,13 +4842,8 @@ private:
|
||||
|
||||
// async wait is supported in Ampere and later
|
||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||
if (computeCapability < 80) {
|
||||
asyncWaitOp.erase();
|
||||
} else if (decomposed) {
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
// Wait for all previous async ops
|
||||
auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
|
||||
asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
|
||||
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
|
||||
decomposed) {
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
@@ -4696,6 +4878,9 @@ public:
|
||||
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
||||
// Conversion, thus we need to make sure the smem is not revised during the
|
||||
// conversion of step 4.
|
||||
rewriteConvertToDotOperand(mod);
|
||||
decomposeMmaToDotOperand(mod, numWarps);
|
||||
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
@@ -4704,6 +4889,7 @@ public:
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
llvm::outs() << mod << "\n";
|
||||
RewritePatternSet scf_patterns(context);
|
||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||
mlir::ConversionTarget scf_target(*context);
|
||||
|
@@ -114,6 +114,7 @@ void populateArithmeticPatternsAndLegality(
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
||||
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
||||
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
|
||||
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -251,6 +252,22 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
|
||||
using OpConversionPattern<triton::CatOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// For now, this behaves like generic, but this will evolve when
|
||||
// we add support for `can_reorder=False`
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
||||
|
||||
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
|
||||
@@ -433,7 +450,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>,
|
||||
TritonCatPattern,
|
||||
TritonReducePattern,
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
|
||||
|
@@ -19,7 +19,7 @@ mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
  for (auto resultType : op->getResultTypes())
    if (failed(verifySameEncoding(resultType, type)))
      return op->emitOpError()
             << "requires the same shape for all operands and results";
             << "requires the same encoding for all operands and results";
  return verifySameOperandsEncoding(op);
}
|
||||
|
||||
|
@@ -196,7 +196,7 @@ public:
    patterns.add<CombineDotAddFRevPattern>(context);
    // %}
    patterns.add<CombineSelectMaskedLoadPattern>(context);
    patterns.add<CombineAddPtrPattern>(context);
    // patterns.add<CombineAddPtrPattern>(context);
    patterns.add<CombineBroadcastConstantPattern>(context);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
|
@@ -29,13 +29,14 @@ def CombineDotAddFRevPattern : Pat<
  (TT_DotOp $a, $b, $d, $allowTF32),
  [(Constraint<CPred<"isZero($0)">> $c)]>;


// TODO: this fails for addptr(addptr(ptr, i32), i64)
// Commented out until fixed
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
def CombineAddPtrPattern : Pat<
  (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
  (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// def CombineAddPtrPattern : Pat<
//   (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
//   (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;

// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
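The pattern being disabled above folds two pointer increments into one. The TODO notes it breaks when the two offsets have different integer widths (e.g. i32 then i64, presumably the kind of mix the relaxed tt.addptr offset type now allows), because the generated arith.addi needs matching operand types. In plain Python terms, the fold is just:

def fold_addptr(ptr, idx0, idx1):
    # addptr(addptr(ptr, idx0), idx1)  ->  addptr(ptr, idx0 + idx1)
    # valid only when idx0 and idx1 can legally be added to each other
    return ("tt.addptr", ptr, idx0 + idx1)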
|
||||
|
@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
    return {};
  unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
  Attribute parent = attrs.get("parent");

  Attribute isMMAv1Row;
  if(parent.isa<MmaEncodingAttr>() &&
     parent.cast<MmaEncodingAttr>().getVersion() == 1){
    isMMAv1Row = attrs.get("isMMAv1Row");
    if(!isMMAv1Row)
      llvm::report_fatal_error("isMMAv1Row attribute is missing");
  }
  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent);
                                                   parent, isMMAv1Row);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "opIdx = " << getOpIdx() << ", "
          << "parent = " << getParent() << "}>";
          << "parent = " << getParent();
  if(getIsMMAv1Row())
    printer << ", isMMAv1Row = " << getIsMMAv1Row();
  printer << "}>";
}
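A small Python mirror of what print() above emits, useful for seeing the textual form the parser on this path has to round-trip (the surrounding #triton_gpu.dot_op syntax is elided, and the rendering of the parent attribute is only indicative):

def print_dot_operand(op_idx, parent, is_mma_v1_row=None) -> str:
    s = f"<{{opIdx = {op_idx}, parent = {parent}"
    if is_mma_v1_row is not None:
        s += f", isMMAv1Row = {str(is_mma_v1_row).lower()}"
    return s + "}>"

# print_dot_operand(0, "#mma", True) -> '<{opIdx = 0, parent = #mma, isMMAv1Row = true}>'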
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -659,15 +668,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||
}
|
||||
|
||||
DenseSet<unsigned>
|
||||
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ASM Interface (i.e.: alias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -713,6 +713,55 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if(auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if(auto srcSharedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return failure();
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return failure();
|
||||
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return failure();
|
||||
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if((order[0] == 1 && isMMAv1Row) ||
|
||||
(order[0] == 0 && !isMMAv1Row))
|
||||
return failure();
|
||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
|
||||
newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(),
|
||||
dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newDstType, cvt.getOperand());
|
||||
rewriter.replaceOp(op, newCvt.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -770,14 +819,28 @@ public:
|
||||
Value b = dotOp.b();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
|
||||
Attribute isMMAv1RowA;
|
||||
Attribute isMMAv1RowB;
|
||||
if(version == 1){
|
||||
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
|
||||
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
|
||||
}
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding()));
|
||||
newRetType.getEncoding(),
|
||||
isMMAv1RowA));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding()));
|
||||
newRetType.getEncoding(),
|
||||
isMMAv1RowB));
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
@@ -808,6 +871,7 @@ public:
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<OptimizeBlockedToShared>(context);
|
||||
// patterns.add<OptimizeConvertToDotOperand>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
@@ -818,6 +882,7 @@ public:
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -134,7 +134,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
      /*printAfterOnlyOnChange=*/true,
      /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

  pm.addPass(createConvertTritonGPUToLLVMPass());
  pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
  // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
|
||||
|
@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
        llvm::SMDiagnostic error;
        std::unique_ptr<llvm::Module> module =
            llvm::parseIR(buffer->getMemBufferRef(), error, context);
        if (!module)
        if (!module) {
          llvm::report_fatal_error(
              "failed to parse IR: " + error.getMessage() +
              "lineno: " + std::to_string(error.getLineNo()));
        }

        // translate module to PTX
        auto ptxCode =
|
||||
|
@@ -64,12 +64,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
    %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
    %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
    %9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
    %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
    %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
    %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
    %3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
    %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
    %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
    %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
    %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
    tt.store %14, %13 : tensor<128x128xf16, #dst>
    return
  }
|
||||
|
@@ -20,6 +20,7 @@ float_dtypes = ['float16', 'float32', 'float64']
|
||||
dtypes = int_dtypes + uint_dtypes + float_dtypes
|
||||
# TODO: handle bfloat16
|
||||
dtypes_with_bfloat16 = dtypes # + ['bfloat16']
|
||||
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes # + ['bfloat16']
|
||||
|
||||
|
||||
def _bitwidth(dtype: str) -> int:
|
||||
@@ -677,6 +678,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
@@ -742,9 +744,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||
x = tl.load(X)
|
||||
x_ptr = X + tl.arange(0, 1)
|
||||
z_ptr = Z + tl.arange(0, 1)
|
||||
x = tl.load(x_ptr)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
|
||||
tl.store(Z, z)
|
||||
tl.store(z_ptr, z)
|
||||
|
||||
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
|
||||
# triton result
|
||||
@@ -1067,21 +1071,23 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# # ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
[(epilogue, allow_tf32, dtype)
|
||||
@pytest.mark.parametrize("M, N, K, epilogue, allow_tf32, dtype",
|
||||
[(*shape, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16']
|
||||
for dtype in ['float16', 'float32']
|
||||
if not (allow_tf32 and (dtype in ['float16']))])
|
||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 80:
|
||||
if capability[0] < 8:
|
||||
if dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
|
||||
M, N, K = 64, 64, 64
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
num_warps = 4
|
||||
trans_a, trans_b = False, False
|
||||
|
||||
@@ -1126,7 +1132,8 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws))
|
||||
w = tl.load(Ws)
|
||||
z = tl.dot(z.to(w.dtype), w)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
@@ -1176,14 +1183,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
if dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32':
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
@@ -1223,9 +1234,41 @@ def test_arange(start, device='cuda'):
|
||||
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
|
||||
# # ---------------
|
||||
# # test load
|
||||
# # ---------------
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

input_size = size - size_diff
output_size = size
if dtype_str == 'bool':
input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
else:
input = torch.rand(input_size, dtype=dtype, device=device)
output = torch.zeros((output_size,), dtype=dtype, device=device)

@triton.jit
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)

_kernel[(1,)](input, output, input_size, output_size)

reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)

# # 'bfloat16': torch.bfloat16,
# # Testing masked loads with an intermediate copy to shared memory run.

@@ -295,18 +295,25 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):


# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[16, 16, 16],
[16, 16, 32],
[32, 16, 16],
[32, 32, 32],
[128, 16, 16],
]
for num_warps in [1]
for trans_a in [False]
for trans_b in [False]
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
# [16, 16, 16, 1, 16, 16, 16, False, False],
# [16, 16, 32, 1, 16, 16, 32, False, False],
# [32, 16, 32, 1, 32, 16, 32, False, False],
# [32, 32, 32, 1, 32, 32, 32, False, False],
# [128, 32, 32, 1, 128, 32, 32, False, False],

# [128, 32, 32, 1, 128, 32, 32, True, False],
# [128, 32, 32, 1, 128, 32, 32, True, True],

# # split-K
# [16, 16, 32, 1, 16, 16, 16, False, False],
# [64, 64, 128, 1, 64, 64, 32, False, False],

# [16, 16, 32, 1, 16, 16, 16, True, False],
# [16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 64, 1, 64, 64, 32, True, False],
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

198
python/tests/test_random.py
Normal file
@@ -0,0 +1,198 @@
import numpy as np
import pytest
import scipy.stats
import torch

import triton
import triton.language as tl

#####################################
# Reference Philox Implementation
#####################################


class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
        self.DTYPE = DTYPE


# This is better for GPU
PHILOX_32 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B9,
    PHILOX_KEY_B=0xBB67AE85,
    PHILOX_ROUND_A=0xD2511F53,
    PHILOX_ROUND_B=0xCD9E8D57,
    DTYPE=np.uint32,
)

# This is what numpy implements
PHILOX_64 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B97F4A7C15,
    PHILOX_KEY_B=0xBB67AE8584CAA73B,
    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
    PHILOX_ROUND_B=0xCA5A826395121157,
    DTYPE=np.uint64,
)


class CustomPhilox4x:
    def __init__(self, seed, config):
        self._config = config
        seed = self._into_pieces(seed)
        self._key = np.array(seed[:2], dtype=self._dtype)
        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)

    @property
    def _dtype(self):
        return self._config.DTYPE

    def _into_pieces(self, n, pad=4):
        res = []
        while len(res) < pad:
            res.append(np.array(n, dtype=self._dtype))
            n >>= (np.dtype(self._dtype).itemsize * 8)
        assert n == 0
        return tuple(res)

    def _multiply_low_high(self, a, b):
        low = a * b
        high = int(a) * int(b)
        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
        return low, high

    def _single_round(self, counter, key):
        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
        ret0 = hi1 ^ counter[1] ^ key[0]
        ret1 = lo1
        ret2 = hi0 ^ counter[3] ^ key[1]
        ret3 = lo0
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
        return key + np.array(pk, dtype=self._dtype)

    def random_raw(self):
        counter = self._counter
        key = self._key
        for _ in range(10):
            counter = self._single_round(counter, key)
            key = self._raise_key(key)
        self.advance(1)
        return counter

    def advance(self, n_steps):
        self._counter[0] += n_steps
        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"


class CustomPhilox(CustomPhilox4x):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.buffer = []

    def random_raw(self):
        if len(self.buffer) == 0:
            self.buffer = list(super().random_raw())[::-1]
        return int(self.buffer.pop())
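As an aside (not part of the diff): a minimal sketch of driving the reference generators above; the seed and sample count are arbitrary choices, not taken from the tests.

# Illustrative only: the scalar wrapper and the 4x generator walk the same stream.
gen = CustomPhilox(seed=0, config=PHILOX_32)
first_scalars = [gen.random_raw() for _ in range(8)]   # eight uint32 samples, one at a time
gen4 = CustomPhilox4x(seed=0, config=PHILOX_32)
first_block = gen4.random_raw()                         # the same stream, four counters at once
assert first_scalars[:4] == [int(v) for v in first_block]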

#####################################
# Unit Tests
#####################################

BLOCK = 1024

# test generation of random uint32


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in ['10', '4,53', '10000']
                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
                         )
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))

    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randint(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.int32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
    # reference result
    gen = CustomPhilox4x(seed, config=PHILOX_32)
    out_ref = [gen.random_raw()[0] for _ in out_tri]
    assert out_tri == out_ref


# test uniform PRNG


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.rand(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert all((x >= 0) & (x <= 1))
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01


# test normal PRNG


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randn(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2


# tl.rand() should never produce >=1.0

def test_rand_limits():
    @triton.jit
    def kernel(input, output, n: tl.constexpr):
        idx = tl.arange(0, n)
        x = tl.load(input + idx)
        y = tl.random.uint32_to_uniform_float(x)
        tl.store(output + idx, y)

    min_max_int32 = torch.tensor([
        torch.iinfo(torch.int32).min,
        torch.iinfo(torch.int32).max,
    ], dtype=torch.int32, device='cuda')
    output = torch.empty(2, dtype=torch.float32, device='cuda')
    kernel[(1,)](min_max_int32, output, 2)

    assert output[0] == output[1]
    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
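The last two assertions pin down the edge case: both int32 extremes must fold onto the same value just below 1.0. One folding that satisfies them, written as a sketch (this mirrors the tested property, not necessarily the exact kernel implementation):

def uint32_to_uniform_float_ref(x: int) -> float:
    # Fold the signed int32 range onto [0, 2**31) and scale by 2**-31 (sketch only).
    x = -x - 1 if x < 0 else x
    return x * 2.0**-31

assert uint32_to_uniform_float_ref(-2**31) == uint32_to_uniform_float_ref(2**31 - 1)
assert 1.0 - torch.finfo(torch.float32).eps <= uint32_to_uniform_float_ref(2**31 - 1) < 1.0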
@@ -371,6 +371,7 @@ class CodeGenerator(ast.NodeVisitor):
# 1. we have an orelse node
# or
# 2. the then block defines new variable
else_defs = {}
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
@@ -381,7 +382,6 @@ class CodeGenerator(ast.NodeVisitor):
else_defs = self.local_defs.copy()
else:
# collect else_defs
else_defs = {}
for name in then_defs:
if name in liveins:
assert self.is_triton_tensor(then_defs[name])
@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
if static_unrolling and len(range) <= 10:
if static_unrolling and len(sta_range) <= 10:
for i in sta_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
@@ -625,10 +625,12 @@ class CodeGenerator(ast.NodeVisitor):
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
if self.local_defs[name].type != liveins[name].type:
local_value = self.local_defs[name]
self.local_defs[name] = local_value.to(liveins[name].dtype, _builder=self.builder)
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))

# create ForOp
self.builder.set_insertion_point_to_end(insert_block)
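The new branch above keeps a loop-carried assignment legal when its dtype drifts from the live-in value, by casting it back before the yield. A sketch of the kind of kernel this accepts (kernel and names are illustrative, not taken from this PR; it assumes X points at float32 data):

@triton.jit
def _running_sum(X, Y, N: tl.constexpr):
    acc = tl.zeros([N], dtype=tl.float16)           # live-in dtype of the carried value
    for i in range(0, 4):
        x = tl.load(X + i * N + tl.arange(0, N))    # float32 input
        acc = acc + x                               # body redefines acc as float32 ...
        # ... the codegen change above casts it back to float16 at the scf.yield
    tl.store(Y + tl.arange(0, N), acc)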
@@ -1400,9 +1402,9 @@ def compile(fn, **kwargs):
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, 70)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
@@ -55,6 +55,7 @@ from .core import (
printf,
program_id,
ravel,
reshape,
sigmoid,
sin,
softmax,
@@ -70,6 +71,7 @@ from .core import (
uint64,
uint8,
umulhi,
view,
void,
where,
xor_sum,
@@ -149,6 +151,7 @@ __all__ = [
"randn",
"randn4x",
"ravel",
"reshape",
"sigmoid",
"sin",
"softmax",
@@ -165,6 +168,7 @@ __all__ = [
"uint64",
"uint8",
"umulhi",
"view",
"void",
"where",
"xor_sum",
@@ -17,11 +17,11 @@ def _to_tensor(x, builder):
if -2**31 <= x < 2**31:
return tensor(builder.get_int32(x), int32)
elif 2**31 <= x < 2**32:
return tensor(builder.get_uint32(x), uint32)
return tensor(builder.get_int32(x), uint32)
elif -2**63 <= x < 2**63:
return tensor(builder.get_int64(x), int64)
elif 2**63 <= x < 2**64:
return tensor(builder.get_uint64(x), uint64)
return tensor(builder.get_int64(x), uint64)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
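For reference, the literal-to-dtype mapping this hunk implements, restated as a standalone sketch (the helper name is made up, not part of the diff):

def literal_dtype(x: int) -> str:
    # Mirrors the range checks in _to_tensor above.
    if -2**31 <= x < 2**31:
        return "int32"
    elif 2**31 <= x < 2**32:
        return "uint32"   # now built from the 32-bit constant, reinterpreted as unsigned
    elif -2**63 <= x < 2**63:
        return "int64"
    elif 2**63 <= x < 2**64:
        return "uint64"   # now built from the 64-bit constant, reinterpreted as unsigned
    raise RuntimeError(f"Nonrepresentable integer {x}.")

assert literal_dtype(0x9E3779B9) == "uint32"   # e.g. the Philox key constant further below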
@@ -731,7 +731,7 @@ def trans(input, _builder=None):
return semantic.trans(input, _builder)

@builtin
def cat(input, other, _builder=None):
def cat(input, other, can_reorder=False, _builder=None):
"""
Concatenate the given blocks

@@ -739,8 +739,12 @@ def cat(input, other, _builder=None):
:type input:
:param other: The second input tensor.
:type other:
:param can_reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs.
Only use if the order does not matter (e.g., result is
only used in reduction ops).
"""
return semantic.cat(input, other, _builder)
return semantic.cat(input, other, can_reorder, _builder)
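A usage sketch of the new hint (kernel and names are illustrative, not from this PR); the concatenation only feeds a reduction, so element order does not matter and can_reorder=True is safe:

@triton.jit
def _fused_sum(X, Y, Out, N: tl.constexpr):
    x = tl.load(X + tl.arange(0, N))
    y = tl.load(Y + tl.arange(0, N))
    both = tl.cat(x, y, can_reorder=True)   # order is irrelevant for a sum
    tl.store(Out, tl.sum(both, axis=0))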

@builtin
@@ -761,7 +765,8 @@ def view(input, shape, _builder=None):
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view
return view(input, shape, _builder)
shape = [x.value for x in shape]
return semantic.view(input, shape, _builder)

# -----------------------
# Linear Algebra
@@ -1,10 +1,10 @@
import triton
from . import core as tl

PHILOX_KEY_A: tl.constexpr = -1640531527  # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019  # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501  # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145  # 0xCD9E8D57
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox
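The removed signed literals are the same 32-bit patterns as the new hex ones; the two's-complement spellings were only needed while literals at or above 2**31 could not be expressed. A quick check (illustrative, not part of the diff):

assert -1640531527 % 2**32 == 0x9E3779B9
assert -1150833019 % 2**32 == 0xBB67AE85
assert -766435501 % 2**32 == 0xD2511F53
assert -845247145 % 2**32 == 0xCD9E8D57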

# -------------------
@@ -498,9 +498,11 @@ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)


def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
# TODO: check types
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
assert can_reorder, "the current implementation of `cat` may reorder elements, so can_reorder must be set"
assert len(lhs.shape) == 1
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)

def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2:
@@ -736,16 +738,18 @@ def load(ptr: tl.tensor,
if other:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)

if other:
other = cast(other, ptr.type.scalar.element_ty, builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty

# treat bool* as tl.int8*
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)

if other:
other = cast(other, elt_ty, builder)

# cache modifier
cache = ir.CACHE_MODIFIER.NONE  # default
if cache_modifier:
@@ -27,8 +27,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
|
||||
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -18,7 +18,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
@@ -26,13 +26,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
@@ -44,7 +44,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
@@ -72,7 +72,7 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
|
||||
@@ -97,9 +97,9 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
@@ -107,8 +107,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
||||
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
@@ -125,9 +125,9 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
@@ -135,7 +135,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
||||
tt.store %15, %13, %10 : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
|
@@ -35,8 +35,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -33,8 +33,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -38,19 +38,19 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
// scalar -> scalar
|
||||
// CHECK: !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
|
||||
|
||||
// 0D tensor -> 0D tensor
|
||||
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
|
||||
%tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
|
||||
// CHECK: tensor<!tt.ptr<f32>>
|
||||
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>
|
||||
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>, tensor<i32>
|
||||
|
||||
// 1D tensor -> 1D tensor
|
||||
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
|
||||
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>>
|
||||
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
|
||||
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -92,9 +92,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Load 4 elements from vector0
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
@@ -111,7 +111,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Store 4 elements to global
|
||||
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
||||
@@ -136,9 +136,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Load 4 elements from A with single one vectorized load instruction
|
||||
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
@@ -150,7 +150,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Store 4 elements to global with single one vectorized store instruction
|
||||
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
|
||||
@@ -173,9 +173,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32, #blocked>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
|
||||
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
|
||||
// load op has a vector width = 1 due to the %mask's alignment
|
||||
@@ -184,7 +184,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
|
||||
tt.store %15, %13, %10 : tensor<64xf32, #blocked>
|
||||
return
|
||||
}
|
||||
@@ -203,9 +203,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Load 8 elements from A with two vectorized load instruction
|
||||
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
@@ -219,7 +219,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Store 8 elements to global with two vectorized store instruction
|
||||
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
|
||||
@@ -317,7 +317,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.getelementptr
|
||||
// CHECK: llvm.getelementptr
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -411,7 +411,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>, tensor<16x64xi32, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
@@ -450,7 +450,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>, tensor<16x64xi32, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
@@ -491,7 +491,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x32x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>, tensor<16x32xi32, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
@@ -535,7 +535,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<32x32xi32, #block3>) -> tensor<32x32xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
||||
|
@@ -22,28 +22,30 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_addptr_pattern
|
||||
|
||||
// COM: CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
// 10 + 15 = 25
|
||||
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
|
||||
// COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
|
||||
|
||||
%base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
|
||||
// CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
// COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
|
||||
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
|
||||
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
|
||||
|
||||
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
|
||||
// COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
|
||||
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
||||
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||
|
@@ -11,9 +11,9 @@ module {
|
||||
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
|
||||
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
|
||||
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%10 = tt.addptr %9, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%10 = tt.addptr %9, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
@@ -31,13 +31,13 @@ module {
|
||||
%22 = arith.addf %19, %21 : tensor<256xf32>
|
||||
%23 = arith.addf %arg7, %22 : tensor<256xf32>
|
||||
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
|
||||
%25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
%26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
|
||||
%27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
|
||||
}
|
||||
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
|
||||
tt.store %17, %15#0, %6 : tensor<256xf32>
|
||||
return
|
||||
}
|
||||
@@ -57,9 +57,9 @@ module {
|
||||
// %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %12 = arith.index_cast %arg4 : i32 to index
|
||||
// %13 = arith.cmpi slt, %c0, %12 : index
|
||||
@@ -72,9 +72,9 @@ module {
|
||||
// %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %26 = arith.cmpi slt, %c32, %12 : index
|
||||
// %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -85,9 +85,9 @@ module {
|
||||
// %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %39 = arith.cmpi slt, %c64, %12 : index
|
||||
// %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -98,16 +98,16 @@ module {
|
||||
// %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index) {
|
||||
// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %61 = arith.addi %arg18, %c32 : index
|
||||
// %62 = arith.cmpi slt, %61, %12 : index
|
||||
// %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -117,13 +117,13 @@ module {
|
||||
// %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index
|
||||
// }
|
||||
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
|
||||
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// return
|
||||
// }
|
||||
|
@@ -31,20 +31,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
|
||||
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
|
||||
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
|
@@ -74,20 +74,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>

@@ -106,7 +106,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>, tensor<64x64xi32, [[row_layout]]>
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]>

@@ -123,12 +123,12 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>

@@ -136,17 +136,17 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>

@@ -160,27 +160,27 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%12 = tt.addptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%15 = tt.addptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%9 = arith.addi %6, %7 : tensor<256xi32, #layout1>
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%11 = arith.addi %4, %5 : tensor<256xi32, #layout1>
%12 = tt.addptr %8, %9 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1>
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%15 = tt.addptr %10, %11 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1>
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%20 = arith.addi %2, %3 : tensor<256xi32, #layout1>
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1>
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}

@@ -65,8 +65,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

@@ -125,8 +125,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f

%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
}

@@ -176,7 +176,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

@@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%32 = arith.addi %30, %31 : tensor<64x64xi32>
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
%37 = arith.muli %35, %36 : tensor<64x1xi32>

@@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%43 = arith.addi %41, %42 : tensor<64x64xi32>
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>

@@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%83 = arith.muli %arg8, %c64_i32 : i32
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
}
%48 = arith.muli %12, %c64_i32 : i32

@@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%64 = arith.addi %62, %63 : tensor<64x64xi32>
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>

@@ -51,8 +51,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>