[CI] run clang-format (#24)
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||
Type type, Location loc) {
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
}
|
@@ -13,14 +13,16 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
@@ -48,50 +50,48 @@ namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
|
||||
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
// other
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
resultType,
|
||||
DenseElementsAttr::get(
|
||||
resultType, builder.getZeroAttr(elementType)
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), resultType,
|
||||
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addOperands(other);
|
||||
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(
|
||||
evictAttrName(state.name),
|
||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name),
|
||||
builder.getBoolAttr(isVolatile));
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
||||
|
||||
using namespace mlir;
|
||||
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
}
|
||||
|
||||
Type PointerType::parse(AsmParser &parser) {
|
||||
|
Reference in New Issue
Block a user