Add TensorSizeTrait
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
#ifndef TRITON_IR_DIALECT_H_
|
#ifndef TRITON_IR_DIALECT_H_
|
||||||
#define TRITON_IR_DIALECT_H_
|
#define TRITON_IR_DIALECT_H_
|
||||||
|
|
||||||
|
#include "triton/ir/Traits.h"
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
|
49
include/triton/ir/Traits.h
Normal file
49
include/triton/ir/Traits.h
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
#ifndef TRITON_IR_TRAITS_H_
|
||||||
|
#define TRITON_IR_TRAITS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace OpTrait {
|
||||||
|
// TODO: should have `namespace triton {}` here
|
||||||
|
|
||||||
|
template <class ConcreteType>
|
||||||
|
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||||
|
public:
|
||||||
|
// TODO: move impl to .cc files
|
||||||
|
static LogicalResult verifyTrait(Operation *op) {
|
||||||
|
for (auto opType : op->getOperandTypes()) {
|
||||||
|
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t s : tensorType.getShape())
|
||||||
|
numElements *= s;
|
||||||
|
if (numElements > 1048576)
|
||||||
|
return op->emitError("Maximum allowed number of elements is 1048576, but ")
|
||||||
|
<< *op << " has more than that";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto opType : op->getResultTypes()) {
|
||||||
|
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t s : tensorType.getShape())
|
||||||
|
numElements *= s;
|
||||||
|
if (numElements > 1048576)
|
||||||
|
return op->emitError("Maximum allowed number of elements is 1048576, but ")
|
||||||
|
<< *op << " has more than that";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -39,11 +39,12 @@ def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
|||||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||||
TT_AnyPtr, TT_PtrTensor]>;
|
TT_AnyPtr, TT_PtrTensor]>;
|
||||||
|
|
||||||
|
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||||
//
|
//
|
||||||
// Op Base
|
// Op Base
|
||||||
//
|
//
|
||||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<Triton_Dialect, mnemonic, traits>;
|
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// CastOps
|
// CastOps
|
||||||
@@ -136,6 +137,8 @@ def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
|||||||
// for args with default values
|
// for args with default values
|
||||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||||
@@ -179,6 +182,8 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
|||||||
let arguments = (ins I32Attr:$axis);
|
let arguments = (ins I32Attr:$axis);
|
||||||
|
|
||||||
let results = (outs I32:$result);
|
let results = (outs I32:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||||
@@ -277,6 +282,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
|||||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||||
|
|
||||||
let results = (outs TT_IntegerTensor:$result);
|
let results = (outs TT_IntegerTensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // Triton_OPS
|
#endif // Triton_OPS
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Verifier.h"
|
||||||
|
|
||||||
#include "triton/ir/Dialect.h"
|
#include "triton/ir/Dialect.h"
|
||||||
#include "triton/ir/Types.h"
|
#include "triton/ir/Types.h"
|
||||||
@@ -742,6 +743,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
|
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
|
||||||
self->insertOperands(self->getNumOperands(), val);
|
self->insertOperands(self->getNumOperands(), val);
|
||||||
})
|
})
|
||||||
|
.def("verify", [](mlir::OpState &self) -> bool {
|
||||||
|
return mlir::succeeded(mlir::verify(self.getOperation()));
|
||||||
|
})
|
||||||
;
|
;
|
||||||
// scf Ops
|
// scf Ops
|
||||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||||
|
Reference in New Issue
Block a user