More progress on TritonGPUTypeConverter & TritonGPUConversionTarget

This commit is contained in:
Yan Da
2022-05-01 22:06:54 +08:00
parent 4ece9fd1f3
commit 1428185c9c
12 changed files with 182 additions and 22 deletions

View File

@@ -9,6 +9,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"

View File

@@ -9,6 +9,9 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_ATTRDEFS
include "TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits>;
@@ -43,7 +44,7 @@ And the associated TritonGPU MLIR
);
}
def TritonGPUCoalescedEncodingAttr : TritonGPU_Attr<"TritonGPUCoalescedEncoding"> {
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
let mnemonic = "coalesced encoding";
let description = [{
@@ -81,7 +82,9 @@ And the associated TritonGPU MLIR
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize
ArrayRefParameter<"unsigned">:$blockTileSize,
// fastest-changing axis first
ArrayRefParameter<"unsigned">:$order
);
// let genVerifyDecl = 1;
@@ -90,7 +93,7 @@ And the associated TritonGPU MLIR
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma encoding";
let description = [{TODO: I think we may be able to implement it as a special-case of Coalesced encoding with maybe one more warpTileSize attribute!}];
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
let parameters = (
ins

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect

View File

@@ -0,0 +1,29 @@
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the TritonGPU dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
private:
MLIRContext *context;
int numThreads;
};
class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx);
};
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_