[OPTIMIZER] Updated TritonGPU-combine pass (#784)
WIP but should work in the cases we need so far
@@ -5,9 +5,13 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

@@ -36,7 +40,7 @@ class SimplifyConversion : public mlir::RewritePattern {
public:
  SimplifyConversion(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}
                             4, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
@@ -86,106 +90,165 @@ public:
//
// -----------------------------------------------------------------------------

static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                                    Attribute &ret) {
  ret = targetEncoding;
  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
    ret = triton::gpu::SliceEncodingAttr::get(
        op->getContext(), expand_dims.axis(), targetEncoding);
  }
  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
    auto sliceEncoding =
        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
    if (!sliceEncoding)
      return failure();
    ret = sliceEncoding.getParent();
  }
  return success();
}

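// A hedged illustration of invertEncoding (the IR and encoding names below are
// assumptions for illustration, not part of this diff): requesting #blocked on
// the result of
//   %e = tt.expand_dims %x {axis = 1 : i32}
//       : (tensor<64xi32, #slice>) -> tensor<64x1xi32, #blocked>
// means the operand %x has to be produced in
//   #triton_gpu.slice<{dim = 1, parent = #blocked}>,
// while pulling a layout backward through tt.reduce recovers the parent of a
// slice encoding; any other target encoding on a reduce makes the inversion
// fail.
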
inline bool expensive_to_remat(Operation *op) {
  if (!op)
    return true;
  if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
          triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
          triton::DotOp>(op))
    return true;
  if (isa<scf::YieldOp, scf::ForOp>(op))
    return true;
  return false;
};

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              BlockAndValueMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(),
      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {
    SmallVector<Type, 1> inferredTypes;
    auto inferResult = typeInfer.inferReturnTypes(
        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
        newOp->getAttrDictionary(), newOp->getRegions(), inferredTypes);
    if (succeeded(inferResult))
      newOp->getResult(0).setType(inferredTypes.front());
  }
  return newOp;
}

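// A minimal usage sketch (hypothetical names, not taken from this diff): when a
// producer has to be rematerialized under a new layout, its operands are first
// remapped to already-converted values, and the clone then picks its result
// encoding up from the remapped first operand:
//   BlockAndValueMapping mapping;
//   mapping.map(oldOperand, convertedOperand);
//   Operation *remat = cloneWithInferType(rewriter, producer, mapping);
// Ops that implement InferTypeOpInterface may further refine the result type.
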
// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
// It therefore makes sense to remove them whenever possible,
// even if it means rematerializing all values whose definitions
// are reachable from it without passing through any memory operation.
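// For instance (hypothetical IR; the layouts #A and #B are assumed), instead of
// paying for the round trip through shared memory in
//   %a = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #A>
//   %b = arith.muli %a, %a : tensor<1024xi32, #A>
//   %c = triton_gpu.convert_layout %b : (tensor<1024xi32, #A>) -> tensor<1024xi32, #B>
// it is cheaper to rematerialize the producers directly in the target layout:
//   %a2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #B>
//   %c  = arith.muli %a2, %a2 : tensor<1024xi32, #B>
// since make_range/constant/splat can materialize any layout for free.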
class PullConversionToSource : public mlir::RewritePattern {
class RematerializeBackward : public mlir::RewritePattern {
public:
  PullConversionToSource(mlir::MLIRContext *context)
  RematerializeBackward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  Attribute invertEncoding(Type targetType, Operation *op) const {
    RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
    if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
      return targetTensorType.getEncoding()
          .cast<triton::gpu::BlockedEncodingAttr>()
          .squeeze(expand_dims.axis());
    }
    return targetTensorType.getEncoding();
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *cvt,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
      return mlir::failure();
    // constants/splat are handled separately
    // we don't touch block arguments
    Operation *op = cvt->getOperand(0).getDefiningOp();
    if (!op)
      return mlir::failure();

    auto blacklist = [](Operation *op) {
      if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
              triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
              triton::DotOp>(op))
        return true;
      if (isa<scf::YieldOp, scf::ForOp>(op))
        return true;
      return false;
    };

    // find all the ops that the conversion depends on
    SetVector<Operation *> depIntoOps;
    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
      return !blacklist(op) && op->getResult(0) &&
             (op->getResult(0).getParentBlock() ==
              cvt->getResult(0).getParentBlock());
    });
    // find all the ops that depend on something expensive
    // to rematerialize and break the dependency
    // chain there
    SetVector<Operation *> blacklistedOps;
    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
    for (Operation *op : blacklistedOps) {
      SetVector<Operation *> toRemove;
      mlir::getBackwardSlice(op, &toRemove);
      depIntoOps.set_subtract(toRemove);
    }
    // if there is nothing that can benefit from moving conversions
    // in the remaining ops, we don't do anything
    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
        // conversions in for loops interfere with the
        // push-to-sink pass. We need a better cost model of how many
        // conversions we can actually remove by moving them to the
        // beginning of the block
        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
        if (!forOp &&
            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
          return true;
      }
      if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
        return true;
      return false;
    });
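    // Informal reading of the predicate above: pulling the conversion towards
    // its source only pays off if its backward slice contains an op where the
    // conversion becomes free -- another convert_layout with a matching type
    // that cancels out (outside of scf.for bodies), or a constant / make_range
    // / splat that can simply be emitted in the new layout.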
    if (it == depIntoOps.end()) {
    // we don't want to rematerialize any conversion to/from shared
    if (isSharedLayout(cvt->getResults()[0]) ||
        isSharedLayout(cvt->getOperand(0)))
      return mlir::failure();
    auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
    // DFS
    SetVector<Operation *> processed;
    SetVector<Attribute> layout;
    std::vector<std::pair<Operation *, Attribute>> queue;
    std::vector<std::pair<Value, Attribute>> toConvert;
    queue.push_back({cvt, targetType.getEncoding()});
    int numCvts = 1;
    while (!queue.empty()) {
      Operation *currOp;
      Attribute currLayout;
      std::tie(currOp, currLayout) = queue.back();
      queue.pop_back();
      // If the current operation is expensive to rematerialize,
      // we stop everything
      if (expensive_to_remat(currOp))
        break;
      // a conversion will be removed here (i.e. transferred to operands)
      numCvts -= 1;
      // done processing
      processed.insert(currOp);
      layout.insert(currLayout);
      // add all operands to the queue
      for (Value argI : currOp->getOperands()) {
        Attribute newEncoding;
        if (failed(invertEncoding(currLayout, currOp, newEncoding)))
          return mlir::failure();
        toConvert.push_back({argI, newEncoding});
        Operation *opArgI = argI.getDefiningOp();
        if (!opArgI)
          continue;
        if (!opArgI || processed.contains(opArgI) ||
            (opArgI->getBlock() != cvt->getBlock()))
          continue;
        // if the conversion can be folded into opArgI then
        // we actually haven't added any conversion
        if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
                triton::MakeRangeOp, triton::SplatOp>(*opArgI))
          continue;
        // we add one expensive conversion for the current operand
        numCvts += 1;
        queue.push_back({opArgI, newEncoding});
      }
    }
    // if rematerialization would add more conversions than it removes
    // then we don't do it
    if (numCvts > 0)
      return mlir::failure();

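    // Worked example of the accounting above (hypothetical IR, not from this
    // diff): for cvt(add(make_range, splat(%s))), we start with one conversion
    // to remove; visiting `cvt` and then `add` transfers the conversion onto
    // add's operands, and since make_range and splat can fold the new layout
    // for free no expensive conversion is added back, so the balance ends at
    // zero and the rewrite is accepted. If one operand came from a tt.load
    // instead, a conversion would be counted in front of the load and never
    // removed, leaving a positive balance that aborts the rewrite.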
    FuncOp parentFunc = cvt->getParentOfType<FuncOp>();
    bool test = cvt->getResult(0)
                    .getType()
                    .cast<RankedTensorType>()
                    .getEncoding()
                    .isa<triton::gpu::MmaEncodingAttr>();
    // if (test)
    //   llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n";

    // We convert cvt(op(arg_0, arg_1, ..., arg_n))
    // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
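    // Concretely (hypothetical values and layouts, for illustration only):
    //   %t = arith.addi %a, %b : tensor<128xi32, #src>
    //   %r = triton_gpu.convert_layout %t : (tensor<128xi32, #src>) -> tensor<128xi32, #dst>
    // becomes
    //   %a2 = triton_gpu.convert_layout %a : (tensor<128xi32, #src>) -> tensor<128xi32, #dst>
    //   %b2 = triton_gpu.convert_layout %b : (tensor<128xi32, #src>) -> tensor<128xi32, #dst>
    //   %r  = arith.addi %a2, %b2 : tensor<128xi32, #dst>
    // so the per-operand conversions can later be folded into producers that
    // are cheap to rematerialize in #dst.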
    BlockAndValueMapping mapping;
    for (Value argI : op->getOperands()) {
      // Compute new argument types
      auto oldArgType = argI.getType().dyn_cast<RankedTensorType>();
      if (!oldArgType)
        continue;
      auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op);
      auto newArgType = RankedTensorType::get(
          oldArgType.getShape(), oldArgType.getElementType(), newEncoding);
      // Create new argument
      auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newArgType, argI);
      cvtI->moveBefore(op);
      mapping.map(argI, cvtI);
    for (int i = toConvert.size() - 1; i >= 0; i--) {
      // unpack information
      Value currOperand;
      Attribute targetLayout;
      std::tie(currOperand, targetLayout) = toConvert[i];
      // if (test)
      //   llvm::outs() << "current " << currOperand << "\n";
      // rematerialize the operand if necessary
      Operation *currOperation = currOperand.getDefiningOp();
      if (processed.contains(currOperation)) {
        currOperation = cloneWithInferType(rewriter, currOperation, mapping);
        currOperand = currOperation->getResult(0);
      }
      if (i == 0)
        break;
      // compute target type for the layout cast
      auto currType = currOperand.getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          currType.getShape(), currType.getElementType(), targetLayout);
      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
          currOperand.getLoc(), newType, currOperand);
      if (currOperation)
        newOperand->moveAfter(currOperation);
      mapping.map(currOperand, newOperand);
    }
    Operation *newOp = rewriter.clone(*op, mapping);
    newOp->getResult(0).setType(cvt->getResult(0).getType());
    rewriter.replaceOp(cvt, newOp->getResults());

    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
    return mlir::success();
  }
};
@@ -226,6 +289,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
std::pair<SmallVector<Value, 4>, scf::ForOp>
tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
                  Type newType) {
  forOp.getInductionVar();
  auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
  auto ctx = forOp.getContext();
  auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
@@ -243,6 +307,7 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
  // traverse all ops in the loop
  for (Operation &op : forOp.getBody()->without_terminator()) {
    // we clone the op
@@ -278,9 +343,9 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
  return {newResults, newForOp};
}

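// A rough sketch of what tryConvertIterArg is meant to do (the IR below is
// hypothetical, not taken from this diff): rebuild the scf.for so that the
// i-th iter arg is carried in newType's layout, roughly turning
//   %out = scf.for ... iter_args(%x = %init) -> (tensor<..., #A>) { ... }
// into
//   %init2 = triton_gpu.convert_layout %init : (tensor<..., #A>) -> tensor<..., #B>
//   %out2  = scf.for ... iter_args(%x = %init2) -> (tensor<..., #B>) { cloned body }
//   %out   = triton_gpu.convert_layout %out2 : (tensor<..., #B>) -> tensor<..., #A>
// so a conversion that used to run on every iteration is paid only once
// outside the loop.
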
class MoveArgConvertOutOfLoop : public mlir::RewritePattern {
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
  MoveArgConvertOutOfLoop(mlir::MLIRContext *context)
  MoveConvertOutOfLoop(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
@@ -290,27 +355,14 @@ public:
    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
    auto iterArgs = forOp.getRegionIterArgs();
    for (auto iterArg : llvm::enumerate(iterArgs)) {
      // skip non-tensor types
      if (!iterArg.value().getType().isa<RankedTensorType>())
        continue;
      // check every use of this iter arg inside the loop
      for (auto op : iterArg.value().getUsers()) {
        auto currOps = mlir::getSlice(op, isInLoop);
        auto pred = [&](Operation *op) {
          return isa<triton::LoadOp, triton::StoreOp>(op);
        };
        auto isCvt = [&](Operation *op) {
          return isa<triton::gpu::ConvertLayoutOp>(op);
        };
        auto isYield = [&](Operation *op) { return isa<scf::YieldOp>(op); };
        auto opIt = std::find(currOps.begin(), currOps.end(), op);
        auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield);
        auto fwdEndIt = std::find_if(opIt, currOps.end(), pred);
        auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred);
        auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt);
        auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt);

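        // How the iterators above are used: within the slice of ops reachable
        // from this use, fwdCvtIt looks for a convert_layout between the use
        // and the next load/store, and bwdCvtIt for one between the previous
        // load/store and the use. A conversion found on the forward path
        // suggests the value is consumed in a different layout than the one
        // carried by the iter arg, so the loop is rebuilt to carry the
        // converted layout instead (below).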
        if (!iterArg.value().getType().isa<RankedTensorType>())
          continue;
        if (fwdCvtIt != fwdEndIt) {
          if (isa<triton::gpu::ConvertLayoutOp>(op)) {
            auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
                                            (*fwdCvtIt)->getResult(0).getType());
                                            op->getResult(0).getType());
            rewriter.replaceOp(forOp, newFor.first);
            return success();
          }
@@ -324,9 +376,9 @@ public:
//
// -----------------------------------------------------------------------------

class PushConversionToSink : public mlir::RewritePattern {
class RematerializeForward : public mlir::RewritePattern {
public:
  PushConversionToSink(mlir::MLIRContext *context)
  RematerializeForward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

@@ -430,16 +482,17 @@ public:
    mlir::RewritePatternSet patterns(context);

    patterns.add<SimplifyConversion>(context);
    patterns.add<PullConversionToSource>(context);
    patterns.add<PushConversionToSink>(context);
    patterns.add<MoveArgConvertOutOfLoop>(context);
    patterns.add<RematerializeBackward>(context);
    patterns.add<RematerializeForward>(context);
    patterns.add<MoveConvertOutOfLoop>(context);
    patterns.add<BlockedToMMA>(context);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
  return std::make_unique<TritonGPUCombineOpsPass>();
}
}
@@ -45,8 +45,8 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
  // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
  // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
  // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
  // CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
  // CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
  // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[target_layout]]>
  // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[target_layout]]>
  // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
  // CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
}
@@ -61,34 +61,10 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {

// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
  // CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
  // CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
  // CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
  // CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
  // CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
  // CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
  // CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
  // CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
  // CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
  // CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
  // CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
  // CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
  // CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
  // CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
  // CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
  // CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
  // CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
  // CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
  // CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
  // CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
  // CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
  // CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
  // CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
  // CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
  // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
  // CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
  // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
  // CHECK-NOT: triton_gpu.convert_layout
  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
  // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
  // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
  // CHECK: return
  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
  %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>