Files
triton/lib/Dialect/TritonGPU/IR/Dialect.cpp

101 lines
2.9 KiB
C++
Raw Normal View History

2022-04-28 18:51:31 +08:00
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
2022-04-28 18:51:31 +08:00
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir::triton::gpu;
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
/// Textual parsing of this attribute is not implemented yet.
///
/// Rather than aborting (llvm_unreachable is undefined behavior in release
/// builds, and this hook can be reached from user-provided textual IR), report
/// a diagnostic at the attribute's name location and return a null attribute
/// so the surrounding parse fails cleanly.
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUDistributedEncodingAttr is not implemented");
  return {};
}
/// Prints the attribute body in the form
///   <threadTileSize = ..., blockTileSize = ..., order = ...>
/// using the attribute's getThreadTileSize/getBlockTileSize/getOrder accessors.
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << "threadTileSize = " << getThreadTileSize();
  printer << ", blockTileSize = " << getBlockTileSize();
  printer << ", order = " << getOrder();
  printer << ">";
}
/// Textual parsing of this attribute is not implemented yet.
///
/// Rather than aborting (llvm_unreachable is undefined behavior in release
/// builds, and this hook can be reached from user-provided textual IR), report
/// a diagnostic at the attribute's name location and return a null attribute
/// so the surrounding parse fails cleanly.
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUMmaEncodingAttr is not implemented");
  return {};
}
/// Prints the attribute body.
///
/// This hook is presumably invoked whenever IR carrying this encoding is
/// printed (e.g. when dumping IR for debugging), so it must not abort. No
/// parameters are printed yet; the body is the empty "<>", matching
/// TritonGPUSharedEncodingAttr::print.
/// TODO: print this attribute's parameters (if any).
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<"
          << ">";
}
/// Textual parsing of this attribute is not implemented yet.
///
/// Rather than aborting (llvm_unreachable is undefined behavior in release
/// builds, and this hook can be reached from user-provided textual IR), report
/// a diagnostic at the attribute's name location and return a null attribute
/// so the surrounding parse fails cleanly.
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
  parser.emitError(parser.getNameLoc(),
                   "parsing TritonGPUSharedEncodingAttr is not implemented");
  return {};
}
/// Prints the attribute body. No parameters are emitted yet, so the body is
/// always the empty "<>".
/// TODO: print this attribute's parameters (if any) once parse() exists.
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << ">";
}
2022-04-28 18:51:31 +08:00
/// Registers the dialect's attributes and operations with MLIR.
/// Both lists are expanded from the generated .cpp.inc files.
void TritonGPUDialect::initialize() {
  // Attributes listed in the generated TritonGPUAttrDefs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();

  // Operations listed in the generated Ops.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
}
2022-05-25 17:53:24 +08:00
namespace mlir {
namespace triton {

// Type inference helpers.

/// Returns the i1 type with the same shape as `type`.
///
/// For a ranked tensor the result is a tensor of i1 with the same shape and
/// encoding. For any other (scalar) type the result is plain i1, since a null
/// Type is never a valid inferred result type.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type,
                                 tensorType.getEncoding());
  return i1Type;
}

/// Returns the type pointed to by `type`:
///   - ranked tensor of pointers -> ranked tensor of the pointee type
///                                  (same shape and encoding)
///   - scalar pointer            -> its pointee type
///   - anything else (including a tensor whose elements are not pointers)
///                               -> null Type
static Type getPointeeType(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    // dyn_cast yields a null PointerType when the element type is not a
    // pointer; calling getPointeeType() on it would dereference null.
    auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
    if (!ptrType)
      return Type();
    return RankedTensorType::get(tensorType.getShape(),
                                 ptrType.getPointeeType(),
                                 tensorType.getEncoding());
  }
  if (auto ptrType = type.dyn_cast<PointerType>())
    return ptrType.getPointeeType();
  return Type();
}

} // namespace triton
} // namespace mlir
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
2022-05-24 19:48:56 +08:00
/// Dialect hook for verifying attribute `attr` attached to operation `op`.
/// Currently a stub: every attribute is accepted.
mlir::LogicalResult
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
                                           mlir::NamedAttribute attr) {
  // TODO: add real checks; until then nothing is rejected.
  (void)op;
  (void)attr;
  return mlir::success();
}