This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
55 lines
2.0 KiB
C++
55 lines
2.0 KiB
C++
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
|
|
namespace {
|
|
|
|
struct CanonicalizePass
|
|
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
|
|
CanonicalizePass() = default;
|
|
|
|
void runOnOperation() override {
|
|
|
|
// Canonicalize pass may have created dead code that
|
|
// standard scf.for canonicalization cannot handle
|
|
// as of LLVM 14. For example, the iteration arguments
|
|
// for the pointer of the synchronous loads that are
|
|
// discarded.
|
|
// The following piece of code is a workaround to
|
|
// very crudely remove dead code, by making an iteration
|
|
// argument yield itself if it is not used to create
|
|
// side effects anywhere.
|
|
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
|
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
|
|
// condition 1: no other iter arguments depend on it
|
|
SetVector<Operation *> fwdSlice;
|
|
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
|
|
Operation *yieldOp = forOp.getBody()->getTerminator();
|
|
bool noOtherDependency = std::all_of(
|
|
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
|
|
return arg == yieldOp->getOperand(i) ||
|
|
!fwdSlice.contains(arg.getDefiningOp());
|
|
});
|
|
// condition 2: final value is not used after the loop
|
|
auto retVal = forOp.getResult(i);
|
|
bool noUserAfterLoop = retVal.getUsers().empty();
|
|
// yielding the region iter arg will cause loop canonicalization
|
|
// to clean up the dead code
|
|
if (noOtherDependency && noUserAfterLoop) {
|
|
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
};
|
|
} // anonymous namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
|
|
return std::make_unique<CanonicalizePass>();
|
|
} |