[BACKEND] Extracting numWarps from tritonGPU module (#39)

This commit is contained in:
Yan Chunwei
2022-08-09 00:40:20 +08:00
committed by GitHub
parent 920723cf3d
commit 83ef74f248
8 changed files with 66 additions and 17 deletions

View File

@@ -5,6 +5,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
using namespace mlir;
@@ -375,6 +376,8 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
class ConvertTritonToTritonGPU
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
ConvertTritonToTritonGPU() = default;
// constructor with some parameters set explicitly.
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
void runOnOperation() override {
@@ -395,6 +398,13 @@ public:
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
auto inti = llvm::APSInt(32, false);
auto i32_ty = IntegerType::get(mod->getContext(), 32);
mod->setAttr(
AttrNumWarpsName,
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
// update layouts
// broadcast src => multicast, dst => broadcasted
// if (failed(target.refineLayouts(mod, numWarps)))
@@ -408,3 +418,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
return std::make_unique<::ConvertTritonToTritonGPU>();
}