13 Commits

Author SHA1 Message Date
Phil Tillet
58d2867fe6 testing things... 2022-12-09 19:31:34 -08:00
Phil Tillet
fa6dbbff60 . 2022-12-08 21:36:47 -08:00
Yan Chunwei
13644e7ac4 adapt isMMAv1Row in backend (#969) 2022-12-08 20:52:43 -08:00
Phil Tillet
0d27912554 [OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout 2022-12-08 18:31:16 -08:00
Keren Zhou
3ed36dcb4d [BACKEND] MMA->DotOperand conversion for chain dot of float32 tensors (#962)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-12-08 20:11:51 +00:00
Keren Zhou
83f3b9165b [FRONTEND][BACKEND] Fix bool and int8 load when the other operand is given (#968) 2022-12-08 11:52:18 -08:00
Keren Zhou
71c35bcf9c [Triton-MLIR][BACKEND] Mark float to integer in Arithmetic Dialect as legal (#963) 2022-12-08 09:07:01 -08:00
Dongdong Li
c7cf9c6a32 [TRITON-MLIR][BACKEND] fix atomic_rmw for vector (#966)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-08 20:03:40 +08:00
Yan Chunwei
f0885e9caf [Triton-MLIR][BACKEND] Tiny patch for MMAv1 and code clean (#964)
This PR:

- Several fixes to MMAv1 code
- Remove the env `TRITON_STATIC_LOOP_UNROLLING` in v100 CI since the
pipeline pass works now
- some code cleanup
2022-12-08 16:39:32 +08:00
Keren Zhou
18e683d9bb [Triton-MLIR][BACKEND] Pass compute capability from the frontend and code cleanup (#961) 2022-12-07 15:03:46 -08:00
Yan Chunwei
4eab9dcedf [Triton-MLIR][BACKEND] make MMAv1 splitk works (#960) 2022-12-07 08:58:38 +00:00
Philippe Tillet
b2b793dfb5 [FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)
Most notably, this PR:
- changes the traits (and assembly format) of addptr so it can handle offsets that have arbitrary integer width.
- adds support for `cat` (a usage sketch follows the commit list below)
2022-12-06 23:29:50 -08:00
Philippe Tillet
981aee7f1e [FRONTEND] Frontend fixes for uint / for loops / random (#958) 2022-12-06 20:25:47 -08:00
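As a reference for the `addptr` and `cat` changes in #959 above, here is a minimal, hypothetical Triton kernel (not taken from this PR; the kernel name, arguments and shapes are made up) showing the two frontend-visible pieces: pointer arithmetic now lowers to tt.addptr with an offset of any integer width, and tl.cat concatenates two blocks. Element ordering of cat is still being pinned down in the backend (see CatOpConversion and the can_reorder hint further down).

import triton
import triton.language as tl

@triton.jit
def concat_blocks(X, Y, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)               # tt.addptr with i32 offsets
    y = tl.load(Y + offs.to(tl.int64))  # i64 offsets are now accepted as well
    z = tl.cat(x, y)                    # new: a (2*BLOCK,) block
    tl.store(Z + tl.arange(0, 2 * BLOCK), z)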
36 changed files with 958 additions and 355 deletions

View File

@@ -89,7 +89,7 @@ jobs:
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
cd python/tests
pytest test_gemm.py::test_gemm_no_scf_for_mmav1
pytest test_gemm.py::test_gemm_for_mmav1
- name: Run CXX unittests
run: |

View File

@@ -103,15 +103,12 @@ def TT_AddPtrOp : TT_Op<"addptr",
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}

View File

@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -32,6 +32,12 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 80;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
@@ -152,7 +158,13 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
//}];
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
// The custom parser could be replaced with oilist in LLVM-16

View File

@@ -43,12 +43,51 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
struct DotOpMmaV1ConversionHelper {
MmaEncodingAttr mmaLayout;
ArrayRef<unsigned> wpt;
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
// Helper to share some variables across multiple functions for A.
struct AParam {
SmallVector<int> rep;
SmallVector<int> spw;
// TODO[Superjomn]: Support the case when isAVec4=false later.
// Currently, we only support ld.v2, since the mma layout varies with
// different ld vector widths.
// bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
const bool isAVec4{true};
explicit AParam(bool isARow) {
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
rep.assign({repM, 0, repK});
spw.assign({spwM, 0, 1});
}
};
// Helper to share some variables across multiple functions for B.
struct BParam {
SmallVector<int> rep;
SmallVector<int> spw;
// TODO[Superjomn]: Support the case when isBVec4=false later.
// Currently, we only support ld.v2, since the mma layout varies with
// different ld vector widths.
// bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
const bool isBVec4{true};
explicit BParam(bool isBRow) {
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep.assign({0, 2 * packSize1, 1});
spw.assign({0, fpw[1] * 4 * rep[1], 1});
}
};
int getRepM(int M) const {
return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
}
@@ -65,24 +104,34 @@ struct DotOpMmaV1ConversionHelper {
return struct_ty(SmallVector<Type>{8, fp32Ty});
}
// number of fp16x2 elements for $a.
int numElemsPerThreadA(RankedTensorType tensorTy) const {
auto shape = tensorTy.getShape();
auto order = getOrder();
// Get the number of fp16x2 elements for $a.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isARow = orderTransed[0] != 0;
AParam param(isARow);
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
int packSize0 = (isARow || isAVec4) ? 1 : 2;
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
return numM;
}
SmallVector<int> fpw({2, 2, 1});
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
// Get the number of fp16x2 elements for $b.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isBRow = orderTransed[0] != 0;
BParam param(isBRow);
int NK = shape[1];
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
return numN;
}
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
int numM = getNumM(shapeTransed, orderTransed);
int NK = shapeTransed[1];
// NOTE: We couldn't get the vec from the shared layout.
// int vecA = sharedLayout.getVec();
@@ -92,34 +141,27 @@ struct DotOpMmaV1ConversionHelper {
return (numM / 2) * (NK / 4) * elemsPerLd;
}
// number of fp16x2 elements for $b.
int numElemsPerThreadB(RankedTensorType tensorTy) const {
auto shape = tensorTy.getShape();
auto order = getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
unsigned numN = getNumN(shapeTransed, orderTransed);
int NK = shapeTransed[0];
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
int elemsPerLd = vecGt4 ? 4 : 2;
int NK = shape[0];
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
return (numN / 2) * (NK / 4) * elemsPerLd;
}
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Value loadA(Value A, bool transA, const SharedMemoryObject &smemObj,
Value thread, Location loc,
ConversionPatternRewriter &rewriter) const;
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Value loadB(Value B, bool transB, const SharedMemoryObject &smemObj,
Value thread, Location loc,
ConversionPatternRewriter &rewriter) const;
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
@@ -1311,8 +1353,22 @@ struct DotOpFMAConversionHelper {
};
Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
// [1, 0] (isRow = True)
// x x x x || x x x x
// x x x x || x x x x
// stride = [8, 1]
// strideA0 = strideAk = 1
// strideA1 = strideAm = 8
// [0, 1] (isRow = False)
// x x x x || x x x x
// x x x x || x x x x
// stride = [1, 2]
// strideA0 = strideAm = 1
// strideA1 = strideAk = 2
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -1322,28 +1378,15 @@ Value DotOpMmaV1ConversionHelper::loadA(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
// Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smemBase = smemObj.base;
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
// TODO[Superjomn]: Support the case when isAVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different
// ld vector width.
isAVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
AParam param(isARow);
SmallVector<int> fpw({2, 2, 1});
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
auto [offsetAM, offsetAK, _0, _1] =
computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
// TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
bool transA = false;
if (transA) {
std::swap(shape[0], shape[1]);
std::swap(offsetAM, offsetAK);
@@ -1358,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
@@ -1372,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value offA0 = isARow ? offsetAK : offsetAM;
Value offA1 = isARow ? offsetAM : offsetAK;
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
offA0 = add(offA0, cSwizzleOffset);
// offA0 = add(offA0, smemObj.offsets[order[0]]);
// offA1 = add(offA1, smemObj.offsets[order[1]]);
SmallVector<Value> offA(numPtrA);
for (int i = 0; i < numPtrA; i++) {
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1391,10 +1437,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
for (int i = 0; i < numPtrA; i++)
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
@@ -1424,6 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);
@@ -1441,8 +1490,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
Value DotOpMmaV1ConversionHelper::loadB(
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
Value tensor, bool transB, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
// smem
auto strides = smemObj.strides;
@@ -1455,20 +1504,11 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different
// ld vector width.
isBVec4 = true;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
int vecB = sharedLayout.getVec();
BParam param(isBRow);
int vecB = sharedLayout.getVec();
Value strideBN = isBRow ? i32_val(1) : strides[1];
Value strideBK = isBRow ? strides[0] : i32_val(1);
Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1476,11 +1516,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
int strideRepN = wpt[1] * fpw[1] * 8;
int strideRepK = 1;
// TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
bool transB = false;
auto [_0, _1, offsetBN, offsetBK] =
computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
if (transB) {
std::swap(order[0], order[1]);
std::swap(shape[0], shape[1]);
@@ -1547,7 +1584,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
}
};
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
unsigned numN = getNumN(shape, order);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))

View File

@@ -973,8 +973,10 @@ struct LoadOpConversion
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
// PTX doesn't support mov.u8, so we need to use mov.u16
auto movWidth = width < 16 ? 16 : width;
PTXInstr &mov =
ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
size_t size = width / valueElemNbits;
@@ -1707,6 +1709,43 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
/// ====================== reduce codegen end ==========================
/// ====================== cat codegen begin ==========================
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
retVals.push_back(v);
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
return success();
}
};
/// ====================== cat codegen end ==========================
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
@@ -2643,63 +2682,24 @@ public:
dstLayout.isa<SliceEncodingAttr>())) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
dstDotLayout.getOpIdx() == 0 &&
dstDotLayout.getParent() == srcMmaLayout) {
// get source values
Location loc = op->getLoc();
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned elems = getElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactored so we don't
// implicitly depend on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
}
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view =
getStructFromElements(loc, reorderedVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
}
static bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}
static void storeBlockedToShared(Value src, Value llSrc,
ArrayRef<Value> srcStrides,
ArrayRef<Value> srcIndices, Value dst,
@@ -2964,6 +2964,11 @@ private:
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
// mma -> dot_operand
LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3170,6 +3175,58 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
return success();
}
LogicalResult ConvertLayoutOpConversion::lowerMmaToDotOperand(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
// get source values
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned elems = getElemsPerThread(srcTy);
Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactored so we don't
// implicitly depend on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
}
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
struct InsertSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -3370,15 +3427,33 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res =
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res =
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
// TODO[Superjomn]: transB is not available here.
bool transB = false;
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc), loc,
rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
@@ -3481,6 +3556,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3492,14 +3575,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot be accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
isAVec4 = true;
isBVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
@@ -3512,20 +3595,24 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
// initialize accumulators
// Initialize accumulators with external values. `acc` holds the accumulator
// values shared between the MMA instructions inside a DotOp; we call the order
// of these values the accumulator-internal order.
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
size_t resSize = acc.size();
// `resVals` holds the final result of the DotOp.
// NOTE: the current order of resVals is different from acc; we call it the
// accumulator-external order.
SmallVector<Value> resVals(resSize);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
auto getIdx = [&](int m, int n) {
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
@@ -3536,8 +3623,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
return idx;
};
{ // convert the acc's values from accumulator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
PTXBuilder builder;
auto idx = getIdx(m, n);
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
@@ -3569,8 +3677,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
// verified before MMA
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};
@@ -3776,7 +3882,8 @@ public:
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
auto shape = type.getShape();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
@@ -3839,13 +3946,22 @@ public:
if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
// TODO[Superjomn]: Both transA and transB are not available here.
bool trans = false;
// TODO[Superjomn]: The order of A and B is not available here.
SmallVector<unsigned> order({1, 0});
if (trans) {
std::swap(shape[0], shape[1]);
std::swap(order[0], order[1]);
}
if (dotOpLayout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(type);
int elems = helper.numElemsPerThreadA(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(type);
int elems = helper.numElemsPerThreadB(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
@@ -4344,6 +4460,7 @@ struct AtomicRMWOpConversion
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
@@ -4401,9 +4518,10 @@ struct AtomicRMWOpConversion
atom.o(rmwOp).o(sTy);
if (valueTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
}
} else {
@@ -4537,6 +4655,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
@@ -4544,6 +4663,34 @@ class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
private:
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) {
// replace `mma -> dot_op` with `mma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMma =
srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcMma && dstDotOp &&
!ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
getOrder(srcMma), numWarps));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
void decomposeBlockedToDotOperand(ModuleOp mod) {
// replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
// because the codegen doesn't handle `blocked -> dot_op` directly
@@ -4571,6 +4718,47 @@ private:
});
}
void rewriteConvertToDotOperand(ModuleOp mod) {
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
OpBuilder builder(cvt);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return;
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return;
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return;
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return;
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvt.getLoc(), newDstType, cvt.getOperand());
cvt.replaceAllUsesWith(newCvt.getResult());
cvt.erase();
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
@@ -4620,8 +4808,7 @@ private:
// capability does not support async copy, then we do decompose
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
computeCapability)
.contains(byteWidth) &&
computeCapability >= 80)
.contains(byteWidth))
return;
// load
@@ -4655,13 +4842,8 @@ private:
// async wait is supported in Ampere and later
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
if (computeCapability < 80) {
asyncWaitOp.erase();
} else if (decomposed) {
OpBuilder builder(asyncWaitOp);
// Wait for all previous async ops
auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
decomposed) {
asyncWaitOp.erase();
}
});
@@ -4696,6 +4878,9 @@ public:
// separation between 1/4 is that, step 3 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 4.
rewriteConvertToDotOperand(mod);
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -4704,6 +4889,7 @@ public:
MembarAnalysis membarPass(&allocation);
membarPass.run();
llvm::outs() << mod << "\n";
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);

View File

@@ -114,6 +114,7 @@ void populateArithmeticPatternsAndLegality(
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}
@@ -251,6 +252,22 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern<triton::CatOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// For now, this behaves like generic, but this will evolve when
// we add support for `can_reorder=False`
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
return success();
}
};
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
@@ -433,7 +450,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonGenericPattern<triton::AddPtrOp>,
TritonCatPattern,
TritonReducePattern,
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
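For context, the GenericOpPattern<arith::FPToSIOp> / GenericOpPattern<arith::FPToUIOp> entries added near the top of this file (#963) make float-to-integer casts convertible to TritonGPU. A minimal, hypothetical kernel that exercises this path (illustrative only, not from this PR's tests; names are made up):

import triton
import triton.language as tl

@triton.jit
def quantize(X, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)          # float32 input
    z = (x * 127.0).to(tl.int8)    # float -> int cast, lowered through arith.fptosi
    tl.store(Z + offs, z)          # Z is assumed to point to int8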

View File

@@ -19,7 +19,7 @@ mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
for (auto resultType : op->getResultTypes())
if (failed(verifySameEncoding(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
<< "requires the same encoding for all operands and results";
return verifySameOperandsEncoding(op);
}

View File

@@ -196,7 +196,7 @@ public:
patterns.add<CombineDotAddFRevPattern>(context);
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
patterns.add<CombineAddPtrPattern>(context);
// patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())

View File

@@ -29,13 +29,14 @@ def CombineDotAddFRevPattern : Pat<
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
// TODO: this fails for addptr(addptr(ptr, i32), i64)
// Commented out until fixed
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
def CombineAddPtrPattern : Pat<
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// def CombineAddPtrPattern : Pat<
// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
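For context on the commented-out CombineAddPtrPattern above: now that tt.addptr accepts offsets of any integer width (#959), chained addptr ops can carry offsets of different widths, so folding them with a single Arith_AddIOp is no longer type-correct, e.g. for addptr(addptr(ptr, i32), i64). A hypothetical kernel that produces such a mixed-width chain (illustrative only; names are made up):

import triton
import triton.language as tl

@triton.jit
def gather_rows(X, Z, row_stride, BLOCK: tl.constexpr):
    cols = tl.arange(0, BLOCK)
    ptrs = X + cols                                 # first tt.addptr, i32 offset
    ptrs = ptrs + (cols * row_stride).to(tl.int64)  # second tt.addptr, i64 offset
    tl.store(Z + cols, tl.load(ptrs))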

View File

@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
if(!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//
@@ -659,15 +668,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
DenseSet<unsigned>
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//

View File

@@ -713,6 +713,55 @@ public:
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -770,14 +819,28 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +871,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
// patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
@@ -818,6 +882,7 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};

View File

@@ -134,7 +134,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.

View File

@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module)
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
auto ptxCode =

View File

@@ -64,12 +64,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}

View File

@@ -20,6 +20,7 @@ float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
# TODO: handle bfloat16
dtypes_with_bfloat16 = dtypes # + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes # + ['bfloat16']
def _bitwidth(dtype: str) -> int:
@@ -677,6 +678,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
@@ -742,9 +744,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
x_ptr = X + tl.arange(0, 1)
z_ptr = Z + tl.arange(0, 1)
x = tl.load(x_ptr)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
tl.store(z_ptr, z)
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
# triton result
@@ -1067,21 +1071,23 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# # ---------------
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, epilogue, allow_tf32, dtype",
[(*shape, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16']
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 80:
if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
M, N, K = 64, 64, 64
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
num_warps = 4
trans_a, trans_b = False, False
@@ -1126,7 +1132,8 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(z.to(tl.float16), tl.load(Ws))
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
tl.store(Zs, z)
# input
rs = RandomState(17)
@@ -1176,14 +1183,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
elif dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@@ -1223,9 +1234,41 @@ def test_arange(start, device='cuda'):
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
# # ---------------
# # test load
# # ---------------
# ---------------
# test load
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
input_size = size - size_diff
output_size = size
if dtype_str == 'bool':
input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
else:
input = torch.rand(input_size, dtype=dtype, device=device)
output = torch.zeros((output_size,), dtype=dtype, device=device)
@triton.jit
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
_kernel[(1,)](input, output, input_size, output_size)
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# # 'bfloat16': torch.bfloat16,
# # Testing masked loads with an intermediate copy to shared memory run.

View File

@@ -295,18 +295,25 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[16, 16, 16],
[16, 16, 32],
[32, 16, 16],
[32, 32, 32],
[128, 16, 16],
]
for num_warps in [1]
for trans_a in [False]
for trans_b in [False]
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
# [16, 16, 16, 1, 16, 16, 16, False, False],
# [16, 16, 32, 1, 16, 16, 32, False, False],
# [32, 16, 32, 1, 32, 16, 32, False, False],
# [32, 32, 32, 1, 32, 32, 32, False, False],
# [128, 32, 32, 1, 128, 32, 32, False, False],
# [128, 32, 32, 1, 128, 32, 32, True, False],
# [128, 32, 32, 1, 128, 32, 32, True, True],
# # split-K
# [16, 16, 32, 1, 16, 16, 16, False, False],
# [64, 64, 128, 1, 64, 64, 32, False, False],
# [16, 16, 32, 1, 16, 16, 16, True, False],
# [16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 64, 1, 64, 64, 32, True, False],
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

python/tests/test_random.py (new file, 198 lines)
View File

@@ -0,0 +1,198 @@
import numpy as np
import pytest
import scipy.stats
import torch
import triton
import triton.language as tl
#####################################
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
self.DTYPE = DTYPE
# This is better for GPU
PHILOX_32 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B9,
PHILOX_KEY_B=0xBB67AE85,
PHILOX_ROUND_A=0xD2511F53,
PHILOX_ROUND_B=0xCD9E8D57,
DTYPE=np.uint32,
)
# This is what numpy implements
PHILOX_64 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B97F4A7C15,
PHILOX_KEY_B=0xBB67AE8584CAA73B,
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
PHILOX_ROUND_B=0xCA5A826395121157,
DTYPE=np.uint64,
)
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
self._key = np.array(seed[:2], dtype=self._dtype)
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
@property
def _dtype(self):
return self._config.DTYPE
def _into_pieces(self, n, pad=4):
res = []
while len(res) < pad:
res.append(np.array(n, dtype=self._dtype))
n >>= (np.dtype(self._dtype).itemsize * 8)
assert n == 0
return tuple(res)
def _multiply_low_high(self, a, b):
low = a * b
high = int(a) * int(b)
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
return low, high
def _single_round(self, counter, key):
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
ret0 = hi1 ^ counter[1] ^ key[0]
ret1 = lo1
ret2 = hi0 ^ counter[3] ^ key[1]
ret3 = lo0
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
return key + np.array(pk, dtype=self._dtype)
def random_raw(self):
counter = self._counter
key = self._key
for _ in range(10):
counter = self._single_round(counter, key)
key = self._raise_key(key)
self.advance(1)
return counter
def advance(self, n_steps):
self._counter[0] += n_steps
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
def random_raw(self):
if len(self.buffer) == 0:
self.buffer = list(super().random_raw())[::-1]
return int(self.buffer.pop())
#####################################
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
gen = CustomPhilox4x(seed, config=PHILOX_32)
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2
# tl.rand() should never produce >=1.0
def test_rand_limits():
@triton.jit
def kernel(input, output, n: tl.constexpr):
idx = tl.arange(0, n)
x = tl.load(input + idx)
y = tl.random.uint32_to_uniform_float(x)
tl.store(output + idx, y)
min_max_int32 = torch.tensor([
torch.iinfo(torch.int32).min,
torch.iinfo(torch.int32).max,
], dtype=torch.int32, device='cuda')
output = torch.empty(2, dtype=torch.float32, device='cuda')
kernel[(1,)](min_max_int32, output, 2)
assert output[0] == output[1]
assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0

View File

@@ -371,6 +371,7 @@ class CodeGenerator(ast.NodeVisitor):
# 1. we have an orelse node
# or
# 2. the then block defines new variable
else_defs = {}
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
@@ -381,7 +382,6 @@ class CodeGenerator(ast.NodeVisitor):
else_defs = self.local_defs.copy()
else:
# collect else_defs
else_defs = {}
for name in then_defs:
if name in liveins:
assert self.is_triton_tensor(then_defs[name])
@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
if static_unrolling and len(range) <= 10:
if static_unrolling and len(sta_range) <= 10:
for i in sta_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
@@ -625,10 +625,12 @@ class CodeGenerator(ast.NodeVisitor):
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
if self.local_defs[name].type != liveins[name].type:
local_value = self.local_defs[name]
self.local_defs[name] = local_value.to(liveins[name].dtype, _builder=self.builder)
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
# create ForOp
self.builder.set_insertion_point_to_end(insert_block)
@@ -1400,9 +1402,9 @@ def compile(fn, **kwargs):
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, 70)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),

View File

@@ -55,6 +55,7 @@ from .core import (
printf,
program_id,
ravel,
reshape,
sigmoid,
sin,
softmax,
@@ -70,6 +71,7 @@ from .core import (
uint64,
uint8,
umulhi,
view,
void,
where,
xor_sum,
@@ -149,6 +151,7 @@ __all__ = [
"randn",
"randn4x",
"ravel",
"reshape",
"sigmoid",
"sin",
"softmax",
@@ -165,6 +168,7 @@ __all__ = [
"uint64",
"uint8",
"umulhi",
"view",
"void",
"where",
"xor_sum",

View File

@@ -17,11 +17,11 @@ def _to_tensor(x, builder):
if -2**31 <= x < 2**31:
return tensor(builder.get_int32(x), int32)
elif 2**31 <= x < 2**32:
return tensor(builder.get_uint32(x), uint32)
return tensor(builder.get_int32(x), uint32)
elif -2**63 <= x < 2**63:
return tensor(builder.get_int64(x), int64)
elif 2**63 <= x < 2**64:
return tensor(builder.get_uint64(x), uint64)
return tensor(builder.get_int64(x), uint64)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
@@ -731,7 +731,7 @@ def trans(input, _builder=None):
return semantic.trans(input, _builder)
@builtin
def cat(input, other, _builder=None):
def cat(input, other, can_reorder=False, _builder=None):
"""
Concatenate the given blocks
@@ -739,8 +739,12 @@ def cat(input, other, _builder=None):
:type input:
:param other: The second input tensor.
:type other:
:param can_reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs.
Only use if the order does not matter (e.g., result is
only used in reduction ops)
"""
return semantic.cat(input, other, _builder)
return semantic.cat(input, other, can_reorder, _builder)
@builtin
@@ -761,7 +765,8 @@ def view(input, shape, _builder=None):
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view
return view(input, shape, _builder)
shape = [x.value for x in shape]
return semantic.view(input, shape, _builder)
# -----------------------
# Linear Algebra
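
From the caller's side, the new can_reorder flag and the reshape change look like the sketch below; the kernel name and block sizes are made up for illustration, and the flag is only safe here because the concatenated values immediately feed a reduction:

import triton
import triton.language as tl


@triton.jit
def cat_then_reduce(X, Y, Out):
    x = tl.load(X + tl.arange(0, 64))
    y = tl.load(Y + tl.arange(0, 64))
    # element order of the result is unspecified, which is acceptable
    # because it is reduced right away
    z = tl.cat(x, y, can_reorder=True)
    z = tl.reshape(z, (2, 64))      # still lowered like view for now
    tl.store(Out + tl.arange(0, 2), tl.sum(z, axis=1))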

View File

@@ -1,10 +1,10 @@
import triton
from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# -------------------
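
The switch to hex literals does not change the generator: the old negative decimals were simply the same 32-bit patterns read as signed int32 values. A quick plain-Python check of the equivalence:

# each pair is (old signed literal, new unsigned literal)
pairs = [(-1640531527, 0x9E3779B9), (-1150833019, 0xBB67AE85),
         (-766435501, 0xD2511F53), (-845247145, 0xCD9E8D57)]
for signed, unsigned in pairs:
    assert signed & 0xFFFFFFFF == unsigned   # identical 32-bit bit patterns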

View File

@@ -498,9 +498,11 @@ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
# TODO: check types
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
assert can_reorder, "current implementation of `cat` always may reorder elements"
assert len(lhs.shape) == 1
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2:
@@ -736,16 +738,18 @@ def load(ptr: tl.tensor,
if other:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
if other:
other = cast(other, ptr.type.scalar.element_ty, builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
if other:
other = cast(other, elt_ty, builder)
# cache modifier
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
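
The pointer retyping above is what lets a masked load from a bool tensor with a default `other` value work end to end. A minimal sketch of a kernel that exercises this path (names and sizes are illustrative, not taken from the test suite):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_bool(X, Out, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    # X points to a torch.bool tensor: the frontend reads it through an int8
    # pointer and casts `other` to int8 before emitting the load
    x = tl.load(X + offs, mask=mask, other=0)
    tl.store(Out + offs, x, mask=mask)


x = torch.randint(0, 2, (100,), dtype=torch.bool, device='cuda')
out = torch.empty_like(x)
copy_bool[(1,)](x, out, 100, BLOCK=128)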

View File

@@ -27,8 +27,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

View File

@@ -18,7 +18,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
@@ -26,13 +26,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
@@ -44,7 +44,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
@@ -72,7 +72,7 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
%9 = tt.splat %n : (i32) -> tensor<128xi32>
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
@@ -97,9 +97,9 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
%4 = arith.addi %3, %2 : tensor<64xi32>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
@@ -107,8 +107,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%13 = arith.addf %11, %12 : tensor<64xf32>
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %mask : tensor<64xf32>
return
}
@@ -125,9 +125,9 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
%4 = arith.addi %3, %2 : tensor<64xi32>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
@@ -135,7 +135,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%13 = arith.addf %11, %12 : tensor<64xf32>
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %10 : tensor<64xf32>
return
}

View File

@@ -35,8 +35,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

View File

@@ -33,8 +33,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

View File

@@ -38,19 +38,19 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
// scalar -> scalar
// CHECK: !tt.ptr<f32>
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
// 0D tensor -> 0D tensor
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
%tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
// CHECK: tensor<!tt.ptr<f32>>
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>, tensor<i32>
// 1D tensor -> 1D tensor
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
// CHECK: tensor<16x!tt.ptr<f32>>
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
return
}

View File

@@ -92,9 +92,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 4 elements from vector0
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
@@ -111,7 +111,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Store 4 elements to global
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
@@ -136,9 +136,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 4 elements from A with single one vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
@@ -150,7 +150,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Store 4 elements to global with single one vectorized store instruction
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
@@ -173,9 +173,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<64xi32, #blocked>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
// load op has a vector width = 1 due to the %mask's alignment
@@ -184,7 +184,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
%13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
tt.store %15, %13, %10 : tensor<64xf32, #blocked>
return
}
@@ -203,9 +203,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 8 elements from A with two vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
@@ -219,7 +219,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Store 8 elements to global with two vectorized store instruction
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
@@ -317,7 +317,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.getelementptr
// CHECK: llvm.getelementptr
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
return
}
}
@@ -411,7 +411,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>, tensor<16x64xi32, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
%index = arith.constant 1 : i32
@@ -450,7 +450,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>, tensor<16x64xi32, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
@@ -491,7 +491,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x32x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>, tensor<16x32xi32, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
%index = arith.constant 1 : i32
@@ -535,7 +535,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<32x32xi32, #block3>) -> tensor<32x32xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
%index = arith.constant 1 : i32
@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {

View File

@@ -22,28 +22,30 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_addptr_pattern
// COM: CHECK-LABEL: @test_combine_addptr_pattern
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
// 10 + 15 = 25
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
// COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
%base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
// CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
// COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
// COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>

View File

@@ -11,9 +11,9 @@ module {
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%10 = tt.addptr %9, %4 : tensor<256x!tt.ptr<f32>>
%10 = tt.addptr %9, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
%cst = arith.constant 0.000000e+00 : f32
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
%c0_i32 = arith.constant 0 : i32
@@ -31,13 +31,13 @@ module {
%22 = arith.addf %19, %21 : tensor<256xf32>
%23 = arith.addf %arg7, %22 : tensor<256xf32>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
%25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
%25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
%26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
%27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
%27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
}
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
tt.store %17, %15#0, %6 : tensor<256xf32>
return
}
@@ -57,9 +57,9 @@ module {
// %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %12 = arith.index_cast %arg4 : i32 to index
// %13 = arith.cmpi slt, %c0, %12 : index
@@ -72,9 +72,9 @@ module {
// %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %26 = arith.cmpi slt, %c32, %12 : index
// %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
@@ -85,9 +85,9 @@ module {
// %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %39 = arith.cmpi slt, %c64, %12 : index
// %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
@@ -98,16 +98,16 @@ module {
// %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index) {
// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %61 = arith.addi %arg18, %c32 : index
// %62 = arith.cmpi slt, %61, %12 : index
// %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
@@ -117,13 +117,13 @@ module {
// %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index
// }
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// return
// }

View File

@@ -31,20 +31,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
return

View File

@@ -74,20 +74,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
@@ -106,7 +106,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>, tensor<64x64xi32, [[row_layout]]>
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]>
@@ -123,12 +123,12 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
@@ -136,17 +136,17 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
@@ -160,27 +160,27 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%12 = tt.addptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%15 = tt.addptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1>
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1>
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%9 = arith.addi %6, %7 : tensor<256xi32, #layout1>
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%11 = arith.addi %4, %5 : tensor<256xi32, #layout1>
%12 = tt.addptr %8, %9 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1>
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%15 = tt.addptr %10, %11 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1>
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #layout1>
%20 = arith.addi %2, %3 : tensor<256xi32, #layout1>
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #layout1>, tensor<256xi32, #layout1>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1>
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}

View File

@@ -65,8 +65,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return
@@ -125,8 +125,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
}
@@ -176,7 +176,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
return

View File

@@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%32 = arith.addi %30, %31 : tensor<64x64xi32>
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
%37 = arith.muli %35, %36 : tensor<64x1xi32>
@@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%43 = arith.addi %41, %42 : tensor<64x64xi32>
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
@@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%83 = arith.muli %arg8, %c64_i32 : i32
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
}
%48 = arith.muli %12, %c64_i32 : i32
@@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%64 = arith.addi %62, %63 : tensor<64x64xi32>
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>

View File

@@ -51,8 +51,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>