2022-04-28 18:51:31 +08:00
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
|
|
#include "llvm/ADT/TypeSwitch.h"
|
2022-04-28 18:51:31 +08:00
|
|
|
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
|
|
|
|
|
|
|
using namespace mlir::triton::gpu;
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Attribute methods
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
|
|
|
|
|
|
|
// Textual parsing of the distributed encoding is not supported yet.
// Instead of hitting llvm_unreachable (which may be undefined behavior in
// release builds and is reachable from user-supplied IR), report a diagnostic
// at the attribute's location and return a null Attribute, which is how MLIR
// attribute parsers signal failure.
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser,
                                        mlir::Type type) {
  (void)type; // unused until parsing is implemented
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUDistributedEncodingAttr is not implemented");
  return mlir::Attribute();
}
|
|
|
|
|
|
|
|
// Prints the attribute parameters as
//   <threadTileSize = ..., blockTileSize = ..., order = ...>
// with each value rendered by the AsmPrinter's stream operator.
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << "threadTileSize = " << getThreadTileSize();
  printer << ", blockTileSize = " << getBlockTileSize();
  printer << ", order = " << getOrder();
  printer << ">";
}
|
|
|
|
|
|
|
|
// Textual parsing of the MMA encoding is not supported yet.
// Report a diagnostic and return a null Attribute (the MLIR convention for a
// failed attribute parse) rather than calling llvm_unreachable, which may be
// undefined behavior in release builds when fed user-supplied IR.
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  (void)type; // unused until parsing is implemented
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUMmaEncodingAttr is not implemented");
  return mlir::Attribute();
}
|
|
|
|
|
|
|
|
// Prints the MMA encoding. Its parameters are not printed yet; an empty
// parameter list is emitted (matching the shared-encoding printer) so that
// dumping IR containing this attribute no longer aborts via llvm_unreachable.
// NOTE(review): the output is lossy until the real parameters are printed.
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<"
          << ">";
}
|
|
|
|
|
|
|
|
// Textual parsing of the shared encoding is not supported yet.
// Report a diagnostic and return a null Attribute (the MLIR convention for a
// failed attribute parse) rather than calling llvm_unreachable, which may be
// undefined behavior in release builds when fed user-supplied IR.
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  (void)type; // unused until parsing is implemented
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUSharedEncodingAttr is not implemented");
  return mlir::Attribute();
}
|
|
|
|
|
|
|
|
// Prints the shared encoding as an empty parameter list "<>".
// TODO: print this attribute's own parameters once its format is settled.
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << ">";
}
|
|
|
|
|
2022-04-28 18:51:31 +08:00
|
|
|
// Registers the dialect's attributes and operations. Both type lists are
// expanded from TableGen-generated .inc files via the GET_*_LIST macros.
void TritonGPUDialect::initialize() {
  // All attribute classes listed in TritonGPUAttrDefs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();

  // All operation classes listed in Ops.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
}
|
2022-05-01 22:06:54 +08:00
|
|
|
|
2022-05-25 17:53:24 +08:00
|
|
|
namespace mlir {
|
|
|
|
namespace triton {
|
|
|
|
|
|
|
|
// Type inference
|
|
|
|
static Type getI1SameShape(Type type) {
|
|
|
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
|
|
|
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
|
|
|
return Type();
|
|
|
|
}
|
|
|
|
|
|
|
|
static Type getPointeeType(Type type) {
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
|
|
// Tensor of pointers
|
|
|
|
auto shape = tensorType.getShape();
|
|
|
|
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
|
|
|
Type pointeeType = ptrType.getPointeeType();
|
|
|
|
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
|
|
|
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
|
|
|
// scalar pointer
|
|
|
|
Type pointeeType = ptrType.getPointeeType();
|
|
|
|
return pointeeType;
|
|
|
|
}
|
|
|
|
return Type();
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
2022-05-01 22:06:54 +08:00
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
2022-05-24 19:48:56 +08:00
|
|
|
|
|
|
|
|
|
|
|
// verify TritonGPU ops
|
|
|
|
mlir::LogicalResult
|
|
|
|
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
|
|
|
|
mlir::NamedAttribute attr) {
|
|
|
|
// TODO: fill this.
|
|
|
|
return success();
|
|
|
|
}
|