Putting Triton dialect in its own folder

This commit is contained in:
Philippe Tillet
2022-04-26 14:38:28 -07:00
parent 62a64ff29b
commit 81001d318c
15 changed files with 23 additions and 22 deletions

View File

@@ -0,0 +1 @@
add_subdirectory(Triton)

View File

@@ -0,0 +1,18 @@
add_mlir_dialect_library(TritonIR
Dialect.cpp
Ops.cpp
Types.cpp
DEPENDS
TritonTableGen
LINK_LIBS PUBLIC
MLIRIR
MLIRArithmetic
MLIRSCF
# Since LLVM 15
# MLIRFunc
# else
MLIRStandard
)

View File

@@ -0,0 +1,25 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/DialectImplementation.h"
#include "triton/Dialect/Triton/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton;
void TritonDialect::initialize() {
registerTypes();
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/Ops.cpp.inc"
>();
// We can also add interface here.
}

View File

@@ -0,0 +1,99 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type);
return Type();
}
static Type getPointerTypeFromTensor(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType);
}
return Type();
}
}
}
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/Ops.cpp.inc"
// enum attribute definitions
#include "triton/Dialect/Triton/OpsEnums.cpp.inc"
namespace mlir {
namespace triton {
//-- StoreOp --
// Default mask
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
auto shape = ptrType.getShape();
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(),
RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true
)
);
state.addOperands(ptr);
state.addOperands(value);
state.addOperands(mask);
}
//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
// mask
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(),
RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true
)
);
// other
Type resultType = RankedTensorType::get(shape, elementType);
::mlir::Value other = builder.create<arith::ConstantOp>(
ptr.getLoc(),
resultType,
DenseElementsAttr::get(
resultType, builder.getZeroAttr(elementType)
)
);
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,38 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir;
using namespace mlir::triton;
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/Types.cpp.inc"
//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
void TritonDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/Types.cpp.inc"
>();
}
Type PointerType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
return PointerType::get(pointeeType, 0);
}
void PointerType::print(AsmPrinter &printer) const {
printer << "<" << getPointeeType() << ">";
}