[OPTIMIZER] Updated TritonGPU-combine pass (#784)
WIP but should work in the cases we need so far
@@ -5,9 +5,13 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 
@@ -36,7 +40,7 @@ class SimplifyConversion : public mlir::RewritePattern {
 public:
   SimplifyConversion(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             2, context) {}
+                             4, context) {}
 
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -86,106 +90,165 @@ public:
 //
 // -----------------------------------------------------------------------------
 
+static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
+                                    Attribute &ret) {
+  ret = targetEncoding;
+  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
+    ret = triton::gpu::SliceEncodingAttr::get(
+        op->getContext(), expand_dims.axis(), targetEncoding);
+  }
+  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
+    auto sliceEncoding =
+        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
+    if (!sliceEncoding)
+      return failure();
+    ret = sliceEncoding.getParent();
+  }
+  return success();
+}
+
+inline bool expensive_to_remat(Operation *op) {
+  if (!op)
+    return true;
+  if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
+          triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
+          triton::DotOp>(op))
+    return true;
+  if (isa<scf::YieldOp, scf::ForOp>(op))
+    return true;
+  return false;
+};
+
+Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
+                              BlockAndValueMapping &mapping) {
+  Operation *newOp = rewriter.clone(*op, mapping);
+  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
+  auto newType = RankedTensorType::get(
+      origType.getShape(), origType.getElementType(),
+      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
+  newOp->getResult(0).setType(newType);
+  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
+  if (typeInfer) {
+    SmallVector<Type, 1> newType;
+    auto success = typeInfer.inferReturnTypes(
+        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
+        newOp->getAttrDictionary(), newOp->getRegions(), newType);
+    if (succeeded(success))
+      newOp->getResult(0).setType(newType.front());
+  }
+  return newOp;
+}
+
 // Layout conversions are expensive. They require going through
 // shared memory, which is orders of magnitude slower than
 // other non-i/o operations in the dialect.
 // It therefore makes sense to remove them whenever possible,
 // even if it means rematerializing all values whose definitions
 // are reachable from it without passing through any memory operation.
-class PullConversionToSource : public mlir::RewritePattern {
+class RematerializeBackward : public mlir::RewritePattern {
 public:
-  PullConversionToSource(mlir::MLIRContext *context)
+  RematerializeBackward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
 
-  Attribute invertEncoding(Type targetType, Operation *op) const {
-    RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
-    if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
-      return targetTensorType.getEncoding()
-          .cast<triton::gpu::BlockedEncodingAttr>()
-          .squeeze(expand_dims.axis());
-    }
-    return targetTensorType.getEncoding();
-  }
-
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *cvt,
                   mlir::PatternRewriter &rewriter) const override {
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
       return mlir::failure();
-    // constants/splat are handled separately
+    // we don't touch block arguments
     Operation *op = cvt->getOperand(0).getDefiningOp();
     if (!op)
       return mlir::failure();
-    auto blacklist = [](Operation *op) {
-      if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
-              triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
-              triton::DotOp>(op))
-        return true;
-      if (isa<scf::YieldOp, scf::ForOp>(op))
-        return true;
-      return false;
-    };
-
-    // find all the ops that the conversion depends on
-    SetVector<Operation *> depIntoOps;
-    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
-      return !blacklist(op) && op->getResult(0) &&
-             (op->getResult(0).getParentBlock() ==
-              cvt->getResult(0).getParentBlock());
-    });
-    // find all the ops that depend on something expensive
-    // to rematerialize and break the dependency
-    // chain there
-    SetVector<Operation *> blacklistedOps;
-    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
-    for (Operation *op : blacklistedOps) {
-      SetVector<Operation *> toRemove;
-      mlir::getBackwardSlice(op, &toRemove);
-      depIntoOps.set_subtract(toRemove);
-    }
-    // if there is nothing that can benefit from moving conversions
-    // in the remaining op, we don't do anything
-    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
-      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
-        // conversions in for loops interfere with the
-        // push-to-sink pass. Need a better cost model of how many conversions
-        // we can actually remove by moving them to the beginning of the block
-        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
-        if (!forOp &&
-            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
-          return true;
-      }
-      if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
-        return true;
-      return false;
-    });
-    if (it == depIntoOps.end()) {
-      return mlir::failure();
-    }
+    // we don't want to rematerialize any conversion to/from shared
+    if (isSharedLayout(cvt->getResults()[0]) ||
+        isSharedLayout(cvt->getOperand(0)))
+      return mlir::failure();
+
+    auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
+    // DFS
+    SetVector<Operation *> processed;
+    SetVector<Attribute> layout;
+    std::vector<std::pair<Operation *, Attribute>> queue;
+    std::vector<std::pair<Value, Attribute>> toConvert;
+    queue.push_back({cvt, targetType.getEncoding()});
+    int numCvts = 1;
+    while (!queue.empty()) {
+      Operation *currOp;
+      Attribute currLayout;
+      std::tie(currOp, currLayout) = queue.back();
+      queue.pop_back();
+      // If the current operation is expensive to rematerialize,
+      // we stop everything
+      if (expensive_to_remat(currOp))
+        break;
+      // a conversion will be removed here (i.e. transferred to operands)
+      numCvts -= 1;
+      // done processing
+      processed.insert(currOp);
+      layout.insert(currLayout);
+      // add all operands to the queue
+      for (Value argI : currOp->getOperands()) {
+        Attribute newEncoding;
+        if (failed(invertEncoding(currLayout, currOp, newEncoding)))
+          return mlir::failure();
+        toConvert.push_back({argI, newEncoding});
+        Operation *opArgI = argI.getDefiningOp();
+        if (!opArgI)
+          continue;
+        if (!opArgI || processed.contains(opArgI) ||
+            (opArgI->getBlock() != cvt->getBlock()))
+          continue;
+        // if the conversion can be folded into opArgI then
+        // we actually haven't added any conversion
+        if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
+                triton::MakeRangeOp, triton::SplatOp>(*opArgI))
+          continue;
+        // we add one expensive conversion for the current operand
+        numCvts += 1;
+        queue.push_back({opArgI, newEncoding});
+      }
+    }
+    // if rematerialization would add more conversions than it removes
+    // then we don't do it
+    if (numCvts > 0)
+      return mlir::failure();
+
+    FuncOp parentFunc = cvt->getParentOfType<FuncOp>();
+    bool test = cvt->getResult(0)
+                    .getType()
+                    .cast<RankedTensorType>()
+                    .getEncoding()
+                    .isa<triton::gpu::MmaEncodingAttr>();
+    // if (test)
+    //   llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n";
 
-    // We convert cvt(op(arg_0, arg_1, ..., arg_n))
-    // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
     BlockAndValueMapping mapping;
-    for (Value argI : op->getOperands()) {
-      // Compute new argument types
-      auto oldArgType = argI.getType().dyn_cast<RankedTensorType>();
-      if (!oldArgType)
-        continue;
-      auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op);
-      auto newArgType = RankedTensorType::get(
-          oldArgType.getShape(), oldArgType.getElementType(), newEncoding);
-      // Create new argument
-      auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), newArgType, argI);
-      cvtI->moveBefore(op);
-      mapping.map(argI, cvtI);
+    for (int i = toConvert.size() - 1; i >= 0; i--) {
+      // unpack information
+      Value currOperand;
+      Attribute targetLayout;
+      std::tie(currOperand, targetLayout) = toConvert[i];
+      // if (test)
+      //   llvm::outs() << "current " << currOperand << "\n";
+      // rematerialize the operand if necessary
+      Operation *currOperation = currOperand.getDefiningOp();
+      if (processed.contains(currOperation)) {
+        currOperation = cloneWithInferType(rewriter, currOperation, mapping);
+        currOperand = currOperation->getResult(0);
+      }
+      if (i == 0)
+        break;
+      // compute target type for the layout cast
+      auto currType = currOperand.getType().cast<RankedTensorType>();
+      auto newType = RankedTensorType::get(
+          currType.getShape(), currType.getElementType(), targetLayout);
+      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
+          currOperand.getLoc(), newType, currOperand);
+      if (currOperation)
+        newOperand->moveAfter(currOperation);
+      mapping.map(currOperand, newOperand);
     }
-    Operation *newOp = rewriter.clone(*op, mapping);
-    newOp->getResult(0).setType(cvt->getResult(0).getType());
-    rewriter.replaceOp(cvt, newOp->getResults());
+    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
 
     return mlir::success();
   }
 };
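The RematerializeBackward pattern above only fires when pulling a conversion toward its sources removes at least as many layout conversions as it creates: `numCvts` starts at 1 for the conversion being removed, goes down for every operation that can be rematerialized in the target layout, and goes up for every operand that would now need its own conversion. The sketch below is a rough standalone model of that accounting; it uses plain STL types rather than the MLIR classes, and the `Op` / `conversionDelta` names are illustrative, not part of the pass.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an operation in the dependency graph.
struct Op {
  std::string kind;            // e.g. "convert_layout", "make_range", "load"
  std::vector<Op *> operands;  // defining ops of the operands
};

// Mirrors expensive_to_remat: ops we refuse to duplicate.
static bool expensiveToRemat(const Op *op) {
  return !op || op->kind == "load" || op->kind == "store" || op->kind == "dot";
}

// Mirrors the numCvts accounting: start at a convert_layout (one conversion
// that would be removed), walk its backward slice, and count how many new
// conversions the rewrite would have to insert at the leaves.
static int conversionDelta(Op *cvt) {
  int numCvts = 1;                 // the conversion we hope to remove
  std::vector<Op *> queue{cvt};
  std::unordered_set<Op *> processed;
  while (!queue.empty()) {
    Op *curr = queue.back();
    queue.pop_back();
    if (expensiveToRemat(curr))
      break;                       // cannot push the layout past this op
    numCvts -= 1;                  // a conversion disappears here
    processed.insert(curr);
    for (Op *arg : curr->operands) {
      if (!arg || processed.count(arg))
        continue;
      // conversions of constants/splats/ranges fold away for free
      if (arg->kind == "constant" || arg->kind == "splat" ||
          arg->kind == "make_range" || arg->kind == "convert_layout")
        continue;
      numCvts += 1;                // a new conversion appears on this operand
      queue.push_back(arg);
    }
  }
  return numCvts;                  // rewrite only if this is <= 0
}

int main() {
  Op range{"make_range", {}};
  Op mul{"mul", {&range, &range}};
  Op cvt{"convert_layout", {&mul}};
  // delta is 0 here: the conversion folds into make_range for free.
  return conversionDelta(&cvt) <= 0 ? 0 : 1;
}
```

In this toy graph the conversion folds into `make_range` for free, so the delta is non-positive and the rewrite would be worth applying; a chain ending in a load would instead stop the walk early and keep the conversion where it is.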
@@ -226,6 +289,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
 std::pair<SmallVector<Value, 4>, scf::ForOp>
 tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
                   Type newType) {
+  forOp.getInductionVar();
   auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
   auto ctx = forOp.getContext();
   auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
@@ -243,6 +307,7 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
   // traverse all ops in the loop
   for (Operation &op : forOp.getBody()->without_terminator()) {
     // we clone the op
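The added `mapping.map(forOp.getInductionVar(), newForOp.getInductionVar())` line matters because ops cloned from the old loop body may use the induction variable; without that entry they would keep referring to the old loop after cloning. A minimal sketch of the remapping idea follows, using an STL map in place of `BlockAndValueMapping`; the `Value` / `UseOp` / `cloneWithMapping` names are made up for illustration.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Toy stand-ins for SSA values and an op that uses them.
struct Value { int id; };
struct UseOp { std::vector<const Value *> uses; };

// Mirrors what a value mapping does while cloning a region: every operand is
// looked up in the map and replaced if a mapping exists.
static UseOp cloneWithMapping(
    const UseOp &op,
    const std::unordered_map<const Value *, const Value *> &mapping) {
  UseOp clone;
  for (const Value *v : op.uses) {
    auto it = mapping.find(v);
    clone.uses.push_back(it == mapping.end() ? v : it->second);
  }
  return clone;
}

int main() {
  Value oldIv{0}, newIv{1}, oldIterArg{2}, newIterArg{3};
  UseOp body{{&oldIv, &oldIterArg}};

  std::unordered_map<const Value *, const Value *> mapping;
  mapping[&oldIterArg] = &newIterArg;  // iter args were already mapped
  mapping[&oldIv] = &newIv;            // the entry added by this diff

  UseOp cloned = cloneWithMapping(body, mapping);
  // Without the induction-variable entry, cloned.uses[0] would still point
  // into the old loop.
  assert(cloned.uses[0] == &newIv && cloned.uses[1] == &newIterArg);
  return 0;
}
```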
@@ -278,9 +343,9 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
   return {newResults, newForOp};
 }
 
-class MoveArgConvertOutOfLoop : public mlir::RewritePattern {
+class MoveConvertOutOfLoop : public mlir::RewritePattern {
 public:
-  MoveArgConvertOutOfLoop(mlir::MLIRContext *context)
+  MoveConvertOutOfLoop(mlir::MLIRContext *context)
       : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
 
   mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
@@ -290,27 +355,14 @@ public:
     auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
     auto iterArgs = forOp.getRegionIterArgs();
     for (auto iterArg : llvm::enumerate(iterArgs)) {
+      // skip non-tensor types
+      if (!iterArg.value().getType().isa<RankedTensorType>())
+        continue;
+      // check
       for (auto op : iterArg.value().getUsers()) {
-        auto currOps = mlir::getSlice(op, isInLoop);
-        auto pred = [&](Operation *op) {
-          return isa<triton::LoadOp, triton::StoreOp>(op);
-        };
-        auto isCvt = [&](Operation *op) {
-          return isa<triton::gpu::ConvertLayoutOp>(op);
-        };
-        auto isYield = [&](Operation *op) { return isa<scf::YieldOp>(op); };
-        auto opIt = std::find(currOps.begin(), currOps.end(), op);
-        auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield);
-        auto fwdEndIt = std::find_if(opIt, currOps.end(), pred);
-        auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred);
-        auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt);
-        auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt);
-
-        if (!iterArg.value().getType().isa<RankedTensorType>())
-          continue;
-        if (fwdCvtIt != fwdEndIt) {
+        if (isa<triton::gpu::ConvertLayoutOp>(op)) {
           auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
-                                          (*fwdCvtIt)->getResult(0).getType());
+                                          op->getResult(0).getType());
           rewriter.replaceOp(forOp, newFor.first);
           return success();
         }
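MoveConvertOutOfLoop retypes a loop-carried tensor whenever one of its users inside the loop is a layout conversion, so the conversion is paid once outside the loop instead of on every iteration. Below is a self-contained sketch of that trade-off under a toy cost model where each layout change counts as one conversion; the `Tensor` / `convertTo` helpers and the "blocked" / "mma" labels are illustrative, not the pass's API.

```cpp
#include <iostream>
#include <string>

// Toy model: a tensor is tagged with the layout it currently lives in, and
// changing layouts is the expensive operation we want to count.
struct Tensor {
  std::string layout;
};

static int conversions = 0;
static Tensor convertTo(Tensor t, const std::string &layout) {
  if (t.layout != layout) {
    ++conversions;
    t.layout = layout;
  }
  return t;
}

// Loop body that needs its loop-carried operand in the "mma" layout.
static Tensor body(Tensor acc) { return convertTo(acc, "mma"); }

int main() {
  const int tripCount = 64;

  // Before: the iter arg is carried in "blocked" and converted every iteration.
  conversions = 0;
  Tensor acc{"blocked"};
  for (int i = 0; i < tripCount; ++i) {
    Tensor inBody = body(acc);
    acc = convertTo(inBody, "blocked");  // yield back in the old layout
  }
  std::cout << "conversions before the rewrite: " << conversions << "\n";

  // After the rewrite: the iter arg itself is retyped to "mma", so the
  // conversion happens once before the loop and the body needs none.
  conversions = 0;
  Tensor acc2 = convertTo(Tensor{"blocked"}, "mma");
  for (int i = 0; i < tripCount; ++i)
    acc2 = body(acc2);
  std::cout << "conversions after the rewrite:  " << conversions << "\n";
  return 0;
}
```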
@@ -324,9 +376,9 @@ public:
 //
 // -----------------------------------------------------------------------------
 
-class PushConversionToSink : public mlir::RewritePattern {
+class RematerializeForward : public mlir::RewritePattern {
 public:
-  PushConversionToSink(mlir::MLIRContext *context)
+  RematerializeForward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
 
@@ -430,13 +482,14 @@ public:
     mlir::RewritePatternSet patterns(context);
 
     patterns.add<SimplifyConversion>(context);
-    patterns.add<PullConversionToSource>(context);
-    patterns.add<PushConversionToSink>(context);
-    patterns.add<MoveArgConvertOutOfLoop>(context);
+    patterns.add<RematerializeBackward>(context);
+    patterns.add<RematerializeForward>(context);
+    patterns.add<MoveConvertOutOfLoop>(context);
     patterns.add<BlockedToMMA>(context);
 
-    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
+    }
   }
 };
 
@@ -45,8 +45,8 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
   // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
   // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
   // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
-  // CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
-  // CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[target_layout]]>
   // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
   // CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
 }
@@ -61,34 +61,10 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 
 // CHECK-LABEL: transpose
 func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
-  // CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
-  // CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
-  // CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
-  // CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
-  // CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
-  // CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
-  // CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
-  // CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
-  // CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
-  // CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
-  // CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
-  // CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
-  // CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
+  // CHECK-NOT: triton_gpu.convert_layout
+  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>