Add TensorSizeTrait

Yan Da
2022-04-07 15:18:43 +08:00
parent 39fad2b18a
commit 9cf4107990
4 changed files with 63 additions and 1 deletion


@@ -1,6 +1,8 @@
#ifndef TRITON_IR_DIALECT_H_
#define TRITON_IR_DIALECT_H_
#include "triton/ir/Traits.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"


@@ -0,0 +1,49 @@
#ifndef TRITON_IR_TRAITS_H_
#define TRITON_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include <iostream>
namespace mlir {
namespace OpTrait {

// TODO: should have `namespace triton {}` here

template <class ConcreteType>
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
  // TODO: move impl to .cc files
  static LogicalResult verifyTrait(Operation *op) {
    for (auto opType : op->getOperandTypes()) {
      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
        int64_t numElements = 1;
        for (int64_t s : tensorType.getShape())
          numElements *= s;
        if (numElements > 1048576)
          return op->emitError("Maximum allowed number of elements is 1048576, but ")
                 << *op << " has more than that";
      }
    }

    for (auto opType : op->getResultTypes()) {
      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
        int64_t numElements = 1;
        for (int64_t s : tensorType.getShape())
          numElements *= s;
        if (numElements > 1048576)
          return op->emitError("Maximum allowed number of elements is 1048576, but ")
                 << *op << " has more than that";
      }
    }

    return success();
  }
};

}
}

#endif
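
Note (not part of the commit): the check above is simply the product of a ranked tensor's dimensions compared against 2^20 = 1048576. A minimal standalone sketch of that check, with a hypothetical helper name:

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring TensorSizeTrait's check: a ranked tensor is
// accepted only if the product of its dimensions is at most 1048576 (2^20).
static bool withinTritonTensorLimit(const std::vector<int64_t> &shape) {
  int64_t numElements = 1;
  for (int64_t s : shape)
    numElements *= s;
  return numElements <= 1048576;
}

For example, a 1024x1024 tensor sits exactly at the limit, while 2048x1024 would be rejected by verifyTrait.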


@@ -39,11 +39,12 @@ def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
                         TT_AnyPtr, TT_PtrTensor]>;

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;

//
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
    Op<Triton_Dialect, mnemonic, traits>;
    Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
//
// CastOps
@@ -136,6 +137,8 @@ def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
    // for args with default values
    OpBuilder<(ins "Value":$ptr, "Value":$value)>,
  ];

  // let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
}
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
@@ -179,6 +182,8 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
  let arguments = (ins I32Attr:$axis);
  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
@@ -277,6 +282,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
  let arguments = (ins I32Attr:$start, I32Attr:$end);
  let results = (outs TT_IntegerTensor:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}
#endif // Triton_OPS
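
Because TT_Op now appends TensorSizeTrait via !listconcat, every op declared through TT_Op carries the trait. A small usage sketch under that assumption (the function name below is illustrative, not from the commit):

#include "mlir/IR/Operation.h"
#include "triton/ir/Traits.h"

// Illustrative query: any op generated from TT_Op now reports this trait,
// so generic passes can check that the size constraint applies.
static bool hasTensorSizeLimit(mlir::Operation *op) {
  return op->hasTrait<mlir::OpTrait::TensorSizeTrait>();
}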


@@ -6,6 +6,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -742,6 +743,9 @@ void init_triton_ir(py::module &&m) {
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
self->insertOperands(self->getNumOperands(), val);
})
.def("verify", [](mlir::OpState &self) -> bool {
return mlir::succeeded(mlir::verify(self.getOperation()));
})
;
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");