Add TensorSizeTrait
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#ifndef TRITON_IR_DIALECT_H_
|
||||
#define TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "triton/ir/Traits.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
|
49
include/triton/ir/Traits.h
Normal file
49
include/triton/ir/Traits.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef TRITON_IR_TRAITS_H_
|
||||
#define TRITON_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
// TODO: should have `namespace triton {}` here
|
||||
|
||||
template <class ConcreteType>
|
||||
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||
public:
|
||||
// TODO: move impl to .cc files
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > 1048576)
|
||||
return op->emitError("Maximum allowed number of elements is 1048576, but ")
|
||||
<< *op << " has more than that";
|
||||
}
|
||||
}
|
||||
|
||||
for (auto opType : op->getResultTypes()) {
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > 1048576)
|
||||
return op->emitError("Maximum allowed number of elements is 1048576, but ")
|
||||
<< *op << " has more than that";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -39,11 +39,12 @@ def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||
TT_AnyPtr, TT_PtrTensor]>;
|
||||
|
||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, traits>;
|
||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
|
||||
|
||||
//
|
||||
// CastOps
|
||||
@@ -136,6 +137,8 @@ def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
@@ -179,6 +182,8 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
@@ -277,6 +282,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntegerTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
@@ -742,6 +743,9 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
|
||||
self->insertOperands(self->getNumOperands(), val);
|
||||
})
|
||||
.def("verify", [](mlir::OpState &self) -> bool {
|
||||
return mlir::succeeded(mlir::verify(self.getOperation()));
|
||||
})
|
||||
;
|
||||
// scf Ops
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||
|
Reference in New Issue
Block a user