More progress on TritonGPUTypeConverter & TritonGPUConversionTarget

This commit is contained in:
Yan Da
2022-05-01 22:06:54 +08:00
parent 4ece9fd1f3
commit 1428185c9c
12 changed files with 182 additions and 22 deletions

View File

@@ -9,6 +9,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"

View File

@@ -9,6 +9,9 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_ATTRDEFS
include "TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits>;
@@ -43,7 +44,7 @@ And the associated TritonGPU MLIR
);
}
def TritonGPUCoalescedEncodingAttr : TritonGPU_Attr<"TritonGPUCoalescedEncoding"> {
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
let mnemonic = "coalesced encoding";
let description = [{
@@ -81,7 +82,9 @@ And the associated TritonGPU MLIR
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize
ArrayRefParameter<"unsigned">:$blockTileSize,
// fastest-changing axis first
ArrayRefParameter<"unsigned">:$order
);
// let genVerifyDecl = 1;
@@ -90,7 +93,7 @@ And the associated TritonGPU MLIR
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma encoding";
let description = [{TODO: I think we may be able to implement it as a special-case of Coalesced encoding with maybe one more warpTileSize attribute!}];
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
let parameters = (
ins

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect

View File

@@ -0,0 +1,29 @@
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the TritonGPU dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
// Type converter that rewrites Triton tensor types into TritonGPU tensor
// types by attaching a layout encoding derived from the number of threads
// available per block.
class TritonGPUTypeConverter : public TypeConverter {
public:
// context: MLIR context used to build encoding attributes.
// numThreads: threads per block; used to size the distributed layout.
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
private:
MLIRContext *context;
int numThreads;
};
// Conversion target describing which dialects/ops remain legal after the
// Triton -> TritonGPU lowering; the legality rules live in the constructor.
class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx);
};
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_

View File

@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
MLIRPass
TritonIR
TritonGPUIR
)
TritonGPUConversion
)

View File

@@ -1,7 +1,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
using namespace mlir;
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here
[](Operation *op){
return false;
return true;
}
);
// Rewrite rule
@@ -47,26 +48,27 @@ void populateArithmeticPatternsAndLegality(
}
class ConvertTritonToTritonGPU:
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// type converter
TypeConverter typeConverter;
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
void runOnOperation() override {
MLIRContext *context = &getContext();
TritonGPUConversionTarget target(*context);
ModuleOp mod = getOperation();
// int numThreads = mod.getAttr();
// type converter
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
if(failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
if(failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
}

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -1,12 +1,52 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
// Use the toolchain-agnostic LLVM header path; the versioned
// <llvm-6.0/...> prefix only resolves against one specific system install.
#include "llvm/Support/ErrorHandling.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir::triton::gpu;
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
// Textual parsing of the distributed-encoding attribute is not implemented
// yet; any attempt to round-trip it through the assembly format aborts.
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
llvm_unreachable("Not implemented");
}
// Printing counterpart of the parse stub above — also unimplemented.
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
}
// Textual parsing of the MMA-encoding attribute is not implemented yet;
// round-tripping it through the assembly format aborts.
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
llvm_unreachable("Not implemented");
}
// Printing counterpart of the parse stub above — also unimplemented.
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
}
// Textual parsing of the shared-encoding attribute is not implemented yet;
// round-tripping it through the assembly format aborts.
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
llvm_unreachable("Not implemented");
}
// Printing counterpart of the parse stub above — also unimplemented.
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
}
// Registers all TritonGPU operations (generated from the dialect's Ops
// TableGen file) with the dialect at load time.
void TritonGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
}
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

View File

@@ -0,0 +1,10 @@
# Library providing the TritonGPU type converter and conversion target
# (see triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h).
add_mlir_dialect_library(TritonGPUConversion
TritonGPUConversion.cpp
# ADDITIONAL_HEADER_DIRS
LINK_LIBS PUBLIC
TritonIR
TritonGPUIR
# MLIRTransformUtils
)

View File

@@ -0,0 +1,68 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
using namespace mlir;
//
// TypeConverter
//
// Installs the conversion rule that attaches a naive distributed layout
// encoding to every ranked tensor type, splitting `numThreads` threads
// across the tensor dimensions (fastest-changing dimension first).
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numThreads)
    : context(context), numThreads(numThreads) {
  // NOTE: capture by value. The lambda is stored inside the TypeConverter
  // and invoked long after this constructor returns; capturing the
  // constructor parameters by reference ([&]) would leave dangling
  // references when the conversion is later applied.
  addConversion([context, numThreads](RankedTensorType tensorType)
                    -> RankedTensorType {
    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
    Type elementType = tensorType.getElementType();
    int64_t rank = tensorType.getRank();
    int64_t numElements = tensorType.getNumElements();

    // TODO: we should raise an exception here instead of asserting.
    // A tensor with exactly numThreads elements (one element per thread) is
    // legal, so the bound is >=, not >; divisibility is checked separately.
    assert(numElements >= numThreads);
    assert(numElements % numThreads == 0);

    // assert no encoding?

    // Now we assume:
    //   contiguous = 1, order = 0, 1, 2, ...,
    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
    llvm::SmallVector<unsigned> blockTileSize(rank);
    llvm::SmallVector<unsigned> order(rank);
    int remainingThreads = numThreads;
    for (int64_t dim = 0; dim < rank; ++dim) {
      // Give each dimension as many threads as it can absorb (at least 1,
      // at most the dimension's extent), then carry the rest forward.
      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
      order[dim] = dim;
      remainingThreads /= blockTileSize[dim];
      // TODO: will we need repetition?
    }
    Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
        context, threadTileSize, blockTileSize, order);
    return RankedTensorType::get(shape, elementType, encoding);
  });
}
//
// TritonGPUConversion
//
// Declares legality for the Triton -> TritonGPU conversion: Triton,
// arithmetic and SCF ops stay legal as-is, except for the SCF ops the
// lowering cannot handle yet, which are marked illegal so the conversion
// fails loudly on them.
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
: ConversionTarget(context) {
addLegalDialect<triton::TritonDialect,
arith::ArithmeticDialect,
scf::SCFDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
// // We have requirements for the data layouts
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
// Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
// Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
// if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
// bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
// return true;
// return false;
// });
}

View File

@@ -101,6 +101,7 @@ pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_convert_triton_to_tritongpu_pass()
pm.run(mod)
assert mod.verify()