More progress on TritonGPUTypeConverter & TritonGPUConversionTarget
This commit is contained in:
@@ -9,6 +9,7 @@
|
|||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
|
||||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||||
|
|
||||||
|
@@ -9,6 +9,9 @@
|
|||||||
|
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||||
|
|
||||||
|
#define GET_ATTRDEF_CLASSES
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
|
||||||
|
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TRITONGPU_ATTRDEFS
|
#define TRITONGPU_ATTRDEFS
|
||||||
|
|
||||||
include "TritonGPUDialect.td"
|
include "TritonGPUDialect.td"
|
||||||
|
// include "mlir/IR/TensorEncoding.td"
|
||||||
|
|
||||||
class TritonGPU_Attr<string name, list<Trait> traits = []>
|
class TritonGPU_Attr<string name, list<Trait> traits = []>
|
||||||
: AttrDef<TritonGPU_Dialect, name, traits>;
|
: AttrDef<TritonGPU_Dialect, name, traits>;
|
||||||
@@ -43,7 +44,7 @@ And the associated TritonGPU MLIR
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUCoalescedEncodingAttr : TritonGPU_Attr<"TritonGPUCoalescedEncoding"> {
|
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
|
||||||
let mnemonic = "coalesced encoding";
|
let mnemonic = "coalesced encoding";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -81,7 +82,9 @@ And the associated TritonGPU MLIR
|
|||||||
let parameters = (
|
let parameters = (
|
||||||
ins
|
ins
|
||||||
ArrayRefParameter<"unsigned">:$threadTileSize,
|
ArrayRefParameter<"unsigned">:$threadTileSize,
|
||||||
ArrayRefParameter<"unsigned">:$blockTileSize
|
ArrayRefParameter<"unsigned">:$blockTileSize,
|
||||||
|
// fastest-changing axis first
|
||||||
|
ArrayRefParameter<"unsigned">:$order
|
||||||
);
|
);
|
||||||
|
|
||||||
// let genVerifyDecl = 1;
|
// let genVerifyDecl = 1;
|
||||||
@@ -90,7 +93,7 @@ And the associated TritonGPU MLIR
|
|||||||
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
||||||
let mnemonic = "mma encoding";
|
let mnemonic = "mma encoding";
|
||||||
|
|
||||||
let description = [{TODO: I think we may be able to implement it as a special-case of Coalesced encoding with maybe one more warpTileSize attribute!}];
|
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
|
||||||
|
|
||||||
let parameters = (
|
let parameters = (
|
||||||
ins
|
ins
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TRITONGPU_OPS
|
#define TRITONGPU_OPS
|
||||||
|
|
||||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||||
|
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||||
|
@@ -0,0 +1,29 @@
|
|||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Defines utilities to use while converting to the TritonGPU dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||||
|
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||||
|
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class TritonGPUTypeConverter : public TypeConverter {
|
||||||
|
public:
|
||||||
|
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
|
||||||
|
private:
|
||||||
|
MLIRContext *context;
|
||||||
|
int numThreads;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TritonGPUConversionTarget : public ConversionTarget {
|
||||||
|
public:
|
||||||
|
explicit TritonGPUConversionTarget(MLIRContext &ctx);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
|
|||||||
MLIRPass
|
MLIRPass
|
||||||
TritonIR
|
TritonIR
|
||||||
TritonGPUIR
|
TritonGPUIR
|
||||||
|
TritonGPUConversion
|
||||||
)
|
)
|
@@ -1,7 +1,8 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
||||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||||
// TODO: check above rule here
|
// TODO: check above rule here
|
||||||
[](Operation *op){
|
[](Operation *op){
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
// Rewrite rule
|
// Rewrite rule
|
||||||
@@ -53,9 +54,11 @@ class ConvertTritonToTritonGPU:
|
|||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
TritonGPUConversionTarget target(*context);
|
||||||
|
ModuleOp mod = getOperation();
|
||||||
|
// int numThreads = mod.getAttr();
|
||||||
// type converter
|
// type converter
|
||||||
TypeConverter typeConverter;
|
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
|
||||||
// rewrite patterns
|
// rewrite patterns
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
// add rules
|
// add rules
|
||||||
@@ -65,7 +68,6 @@ public:
|
|||||||
if(failed(applyPartialConversion(getOperation(), target,
|
if(failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -1 +1,2 @@
|
|||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
@@ -1,12 +1,52 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||||
|
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||||
|
|
||||||
using namespace mlir::triton::gpu;
|
using namespace mlir::triton::gpu;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Attribute methods
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
#define GET_ATTRDEF_CLASSES
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||||
|
|
||||||
|
/// Parse hook for TritonGPUDistributedEncodingAttr.
///
/// No textual syntax is implemented yet. Instead of marking the path
/// unreachable (undefined behaviour in builds without assertions), report a
/// parse error at the current location and return a null attribute, which is
/// how an attribute parse hook signals failure.
///
/// \param parser  Parser positioned at the attribute body.
/// \param type    Optional attribute type; unused.
/// \return A null attribute, always.
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
  parser.emitError(
      parser.getCurrentLocation(),
      "parsing of TritonGPUDistributedEncodingAttr is not implemented");
  return {};
}
|
||||||
|
|
||||||
|
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse hook for TritonGPUMmaEncodingAttr.
///
/// No textual syntax is implemented yet. Instead of marking the path
/// unreachable (undefined behaviour in builds without assertions), report a
/// parse error at the current location and return a null attribute, which is
/// how an attribute parse hook signals failure.
///
/// \param parser  Parser positioned at the attribute body.
/// \param type    Optional attribute type; unused.
/// \return A null attribute, always.
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  parser.emitError(parser.getCurrentLocation(),
                   "parsing of TritonGPUMmaEncodingAttr is not implemented");
  return {};
}
|
||||||
|
|
||||||
|
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse hook for TritonGPUSharedEncodingAttr.
///
/// No textual syntax is implemented yet. Instead of marking the path
/// unreachable (undefined behaviour in builds without assertions), report a
/// parse error at the current location and return a null attribute, which is
/// how an attribute parse hook signals failure.
///
/// \param parser  Parser positioned at the attribute body.
/// \param type    Optional attribute type; unused.
/// \return A null attribute, always.
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  parser.emitError(parser.getCurrentLocation(),
                   "parsing of TritonGPUSharedEncodingAttr is not implemented");
  return {};
}
|
||||||
|
|
||||||
|
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/// Registers everything the TritonGPU dialect provides: its attributes and
/// its operations.
void TritonGPUDialect::initialize() {
  // The encoding attributes must be registered with the dialect before any
  // `...Attr::get(context, ...)` call can create them; the generated list
  // comes from the same .inc file that defines the attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
}
|
||||||
|
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
|
10
lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Normal file
10
lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
add_mlir_dialect_library(TritonGPUConversion
|
||||||
|
TritonGPUConversion.cpp
|
||||||
|
|
||||||
|
# ADDITIONAL_HEADER_DIRS
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
TritonIR
|
||||||
|
TritonGPUIR
|
||||||
|
# MLIRTransformUtils
|
||||||
|
)
|
68
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
Normal file
68
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//
|
||||||
|
// TypeConverter
|
||||||
|
//
|
||||||
|
/// Builds the type converter used when converting to the TritonGPU dialect.
///
/// Registers one conversion: a RankedTensorType is rebuilt with the same shape
/// and element type plus a TritonGPUDistributedEncodingAttr describing a naive
/// layout (one element per thread tile, axes ordered 0, 1, 2, ...).
///
/// \param context     MLIR context used to create the encoding attribute.
/// \param numThreads  Number of threads distributed across the tensor's axes
///                    when computing the block tile size.
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numThreads)
    : context(context), numThreads(numThreads) {
  // The callback is stored inside the converter and invoked after this
  // constructor has returned. Inside this body `context` and `numThreads` name
  // the constructor parameters, so they are captured by value; capturing them
  // by reference would leave the callback with dangling references.
  addConversion([context, numThreads](
                    RankedTensorType tensorType) -> RankedTensorType {
    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
    Type elementType = tensorType.getElementType();
    int64_t rank = tensorType.getRank();
    int64_t numElements = tensorType.getNumElements();

    // TODO: we should raise exception here.
    assert(numElements > numThreads);
    assert(numElements % numThreads == 0);

    // assert no encoding?

    // Now we assume:
    //   contiguous = 1, order = 0, 1, 2, ...,
    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
    llvm::SmallVector<unsigned> blockTileSize(rank);
    llvm::SmallVector<unsigned> order(rank);
    int remainingThreads = numThreads;
    for (int64_t dim = 0; dim < rank; ++dim) {
      // Give this axis as many of the remaining threads as it has elements,
      // but at least one.
      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
      order[dim] = dim;

      remainingThreads /= blockTileSize[dim];
      // TODO: will we need repetition?
    }
    Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
        context, threadTileSize, blockTileSize, order);
    return RankedTensorType::get(shape, elementType, encoding);
  });
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// TritonGPUConversion
|
||||||
|
//
|
||||||
|
/// Sets up the legality rules of the TritonGPU conversion target.
///
/// \param context  MLIR context forwarded to the ConversionTarget base class.
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
    : ConversionTarget(context) {
  // Dialects whose operations are legal as a whole.
  addLegalDialect<triton::TritonDialect>();
  addLegalDialect<arith::ArithmeticDialect>();
  addLegalDialect<scf::SCFDialect>();

  // Individual SCF operations that are nevertheless illegal.
  addIllegalOp<scf::ExecuteRegionOp>();
  addIllegalOp<scf::ParallelOp>();
  addIllegalOp<scf::ReduceOp>();
  addIllegalOp<scf::ReduceReturnOp>();

  // Disabled for now: data-layout requirement on the operands of dot.
  //
  // addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
  //   Attribute aEncoding =
  //       dotOp.a().getType().cast<RankedTensorType>().getEncoding();
  //   Attribute bEncoding =
  //       dotOp.b().getType().cast<RankedTensorType>().getEncoding();
  //   if (aEncoding &&
  //       aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
  //       bEncoding &&
  //       bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
  //     return true;
  //   return false;
  // });
}
|
@@ -101,6 +101,7 @@ pm = _triton.ir.pass_manager(ctx)
|
|||||||
pm.add_inliner_pass()
|
pm.add_inliner_pass()
|
||||||
pm.add_triton_combine_pass()
|
pm.add_triton_combine_pass()
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
|
pm.add_convert_triton_to_tritongpu_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
|
|
||||||
assert mod.verify()
|
assert mod.verify()
|
||||||
|
Reference in New Issue
Block a user