From 9cf4107990c0ed98d87733fd7f19478912a76641 Mon Sep 17 00:00:00 2001
From: Yan Da <dyanab@connect.ust.hk>
Date: Thu, 7 Apr 2022 15:18:43 +0800
Subject: [PATCH] Add TensorSizeTrait

---
 include/triton/ir/Dialect.h    |  2 ++
 include/triton/ir/Traits.h     | 49 ++++++++++++++++++++++++++++++++++
 include/triton/ir/TritonOps.td |  9 ++++++++-
 python/src/triton.cc           |  4 +++
 4 files changed, 63 insertions(+), 1 deletion(-)
 create mode 100644 include/triton/ir/Traits.h

diff --git a/include/triton/ir/Dialect.h b/include/triton/ir/Dialect.h
index f26cccac8..7234627e7 100644
--- a/include/triton/ir/Dialect.h
+++ b/include/triton/ir/Dialect.h
@@ -1,6 +1,8 @@
 #ifndef TRITON_IR_DIALECT_H_
 #define TRITON_IR_DIALECT_H_
 
+#include "triton/ir/Traits.h"
+
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
diff --git a/include/triton/ir/Traits.h b/include/triton/ir/Traits.h
new file mode 100644
index 000000000..534722d71
--- /dev/null
+++ b/include/triton/ir/Traits.h
@@ -0,0 +1,49 @@
+#ifndef TRITON_IR_TRAITS_H_
+#define TRITON_IR_TRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include <iostream>
+
+namespace mlir {
+namespace OpTrait {
+// TODO: should have `namespace triton {}` here
+
+template <typename ConcreteType>
+class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
+public:
+  // TODO: move impl to .cc files
+  static LogicalResult verifyTrait(Operation *op) {
+    for (auto opType : op->getOperandTypes()) {
+      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+        int64_t numElements = 1;
+        for (int64_t s : tensorType.getShape())
+          numElements *= s;
+        if (numElements > 1048576)
+          return op->emitError("Maximum allowed number of elements is 1048576, but ")
+                 << *op << " has more than that";
+      }
+    }
+
+    for (auto opType : op->getResultTypes()) {
+      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+        int64_t numElements = 1;
+        for (int64_t s : tensorType.getShape())
+          numElements *= s;
+        if (numElements > 1048576)
+          return op->emitError("Maximum allowed number of elements is 1048576, but ")
+                 << *op << " has more than that";
+      }
+    }
+
+    return success();
+  }
+};
+
+}
+}
+
+#endif
diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td
index 5c04398a8..4d6319944 100644
--- a/include/triton/ir/TritonOps.td
+++ b/include/triton/ir/TritonOps.td
@@ -39,11 +39,12 @@ def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
 def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
                          TT_AnyPtr, TT_PtrTensor]>;
 
+def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
 //
 // Op Base
 //
 class TT_Op<string mnemonic, list<OpTrait> traits = []> :
-    Op<Triton_Dialect, mnemonic, traits>;
+    Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
 //
 // CastOps
 //
@@ -136,6 +137,8 @@ def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
     // for args with default values
     OpBuilder<(ins "Value":$ptr, "Value":$value)>,
   ];
+
+  // let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
 }
 
 def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
@@ -179,6 +182,8 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
   let arguments = (ins I32Attr:$axis);
 
   let results = (outs I32:$result);
+
+  let assemblyFormat = "attr-dict `:` type($result)";
 }
 
 def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
@@ -277,6 +282,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
   let arguments = (ins I32Attr:$start, I32Attr:$end);
 
   let results = (outs TT_IntegerTensor:$result);
+
+  let assemblyFormat = "attr-dict `:` type($result)";
 }
 
 #endif // Triton_OPS
diff --git a/python/src/triton.cc b/python/src/triton.cc
index e88e22641..6cc60c761 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -6,6 +6,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
 #include "triton/ir/Dialect.h"
 #include "triton/ir/Types.h"
 
@@ -742,6 +743,9 @@ void init_triton_ir(py::module &&m) {
       .def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
         self->insertOperands(self->getNumOperands(), val);
       })
+      .def("verify", [](mlir::OpState &self) -> bool {
+        return mlir::succeeded(mlir::verify(self.getOperation()));
+      })
       ;
   // scf Ops
   py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");