More progress on TritonGPUTypeConverter & TritonGPUConversionTarget
This commit is contained in:
@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
|
||||
MLIRPass
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
)
|
||||
TritonGPUConversion
|
||||
)
|
||||
|
@@ -1,7 +1,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// TODO: check above rule here
|
||||
[](Operation *op){
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
);
|
||||
// Rewrite rule
|
||||
@@ -47,26 +48,27 @@ void populateArithmeticPatternsAndLegality(
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU:
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
// type converter
|
||||
TypeConverter typeConverter;
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
TritonGPUConversionTarget target(*context);
|
||||
ModuleOp mod = getOperation();
|
||||
// int numThreads = mod.getAttr();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
|
||||
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
}
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
@@ -1,12 +1,52 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
llvm_unreachable("Not implemented");
|
||||
}
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
10
lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Normal file
10
lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
add_mlir_dialect_library(TritonGPUConversion
|
||||
TritonGPUConversion.cpp
|
||||
|
||||
# ADDITIONAL_HEADER_DIRS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
# MLIRTransformUtils
|
||||
)
|
68
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
Normal file
68
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//
|
||||
// TypeConverter
|
||||
//
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numThreads)
|
||||
: context(context), numThreads(numThreads) {
|
||||
addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
|
||||
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
|
||||
Type elementType = tensorType.getElementType();
|
||||
int64_t rank = tensorType.getRank();
|
||||
int64_t numElements = tensorType.getNumElements();
|
||||
|
||||
// TODO: we should raise exception here.
|
||||
assert(numElements > numThreads);
|
||||
assert(numElements % numThreads == 0);
|
||||
|
||||
// assert no encoding?
|
||||
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
int remainingThreads = numThreads;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
||||
order[dim] = dim;
|
||||
|
||||
remainingThreads /= blockTileSize[dim];
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
|
||||
context, threadTileSize, blockTileSize, order);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// TritonGPUConversion
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||
: ConversionTarget(context) {
|
||||
addLegalDialect<triton::TritonDialect,
|
||||
arith::ArithmeticDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
// // We have requirements for the data layouts
|
||||
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||
// Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
// Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
// if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
// bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
// return true;
|
||||
// return false;
|
||||
// });
|
||||
|
||||
}
|
Reference in New Issue
Block a user