## Features

- Allow taking a block of a tensor slice, as long as each dimension is contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantics. A follow-up PR may add semantics similar to `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed.
- May cause a slight performance slowdown, since strides are now passed as values rather than constants (e.g., `int`). Passing strides as attributes would be difficult in the presence of control flow, because a block argument can accept tensors with different strides (see the sketch after this list).
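To make the last limitation concrete, here is a minimal sketch (not code from this PR; the helper name `getConstantStride` and the exact arith-dialect header path are assumptions) of why value strides defeat constant specialization: a stride operand only yields a compile-time constant when its defining op is a constant, while a block argument has no defining op at all.

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h" // header path varies across MLIR versions
#include "mlir/IR/Value.h"
#include <cstdint>
#include <optional>

// Hypothetical helper: try to recover a compile-time stride from an SSA value.
static std::optional<int64_t> getConstantStride(mlir::Value stride) {
  // Succeeds only when the stride is produced by an arith constant op.
  if (auto cst = stride.getDefiningOp<mlir::arith::ConstantIntOp>())
    return cst.value();
  // Block arguments (e.g., a loop-carried stride under control flow) have no
  // defining op, so the stride must stay a dynamic value.
  return std::nullopt;
}
```

The updated trait verifiers follow.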
#include "triton/Dialect/Triton/IR/Traits.h"
|
|
|
|
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
|
|
using namespace mlir;
|
|
auto encA = tyA.dyn_cast<RankedTensorType>();
|
|
auto encB = tyA.dyn_cast<RankedTensorType>();
|
|
if (!encA || !encB)
|
|
return success();
|
|
return encA.getEncoding() == encB.getEncoding() ? success() : failure();
|
|
}
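The helper is deliberately permissive: shapes are never compared, only encodings, and non-tensor types always pass. A small standalone illustration (not from the patch; the builtin-type helpers shown are standard MLIR but their exact spelling varies by version):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

void illustrateSameEncoding() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::FloatType::getF32(&ctx);
  // Different shapes, both with a null encoding attribute:
  auto a = mlir::RankedTensorType::get({16, 16}, f32);
  auto b = mlir::RankedTensorType::get({32, 8}, f32);
  // verifySameEncoding(a, b)   -> success(): null encodings compare equal.
  // verifySameEncoding(f32, a) -> success(): f32 is not a ranked tensor.
}
```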
```cpp
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  // Compare every result's encoding against the first operand's, then
  // delegate to the operands-only check below.
  auto type = op->getOperand(0).getType();
  for (auto resultType : op->getResultTypes())
    if (failed(verifySameEncoding(resultType, type)))
      return op->emitOpError()
             << "requires the same encoding for all operands and results";
  return verifySameOperandsEncoding(op);
}
```
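These `impl` entry points are normally surfaced as native op traits, so that ops opt in declaratively and `verify()` routes into them. A minimal sketch of the usual MLIR wrapper pattern (the actual declaration in `Traits.h` may differ in detail):

```cpp
#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace OpTrait {

// Standard NativeOpTrait-style wrapper: an op that lists this trait gets
// verifySameOperandsAndResultEncoding run as part of its verifier.
template <typename ConcreteType>
class SameOperandsAndResultEncoding
    : public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultEncoding(op);
  }
};

} // namespace OpTrait
} // namespace mlir
```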
```cpp
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  auto type = op->getOperand(0).getType();
  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (failed(verifySameEncoding(opType, type)))
      return op->emitOpError() << "requires the same encoding for all operands";

  return success();
}
```
The tensor-size verifier enforces a hard cap and a power-of-two element count on every operand and result tensor:

```cpp
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
  // Check each operand tensor: the element count must not exceed the cap
  // (maxTensorNumElements comes from the included Traits.h) and must be a
  // power of two.
  for (auto opType : op->getOperandTypes()) {
    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
      int64_t numElements = 1;
      for (int64_t s : tensorType.getShape())
        numElements *= s;
      if (numElements > maxTensorNumElements)
        return op->emitError("Maximum allowed number of elements is ")
               << maxTensorNumElements << ", but " << *op
               << " has more than that";
      if ((numElements & (numElements - 1)) != 0)
        return op->emitError("Number of elements must be power-of-two, but ")
               << *op << " doesn't follow the rule";
    }
  }
  // Apply the same two checks to each result tensor.
  for (auto opType : op->getResultTypes()) {
    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
      int64_t numElements = 1;
      for (int64_t s : tensorType.getShape())
        numElements *= s;
      if (numElements > maxTensorNumElements)
        return op->emitError("Maximum allowed number of elements is ")
               << maxTensorNumElements << ", but " << *op
               << " has more than that";
      if ((numElements & (numElements - 1)) != 0)
        return op->emitError("Number of elements must be power-of-two, but ")
               << *op << " doesn't follow the rule";
    }
  }
  return success();
}
```
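The power-of-two test relies on the classic bit trick: a positive power of two has exactly one set bit, so `n & (n - 1)` clears it to zero, while any other positive `n` keeps at least one bit set. A standalone sketch (not part of the patch; note that `numElements` in the verifier is always at least 1, so the `n == 0` corner case cannot arise there):

```cpp
#include <cassert>
#include <cstdint>

// n is a positive power of two <=> n > 0 and n & (n - 1) == 0.
static bool isPowerOfTwo(int64_t n) {
  return n > 0 && (n & (n - 1)) == 0;
}

int main() {
  assert(isPowerOfTwo(1) && isPowerOfTwo(64) && isPowerOfTwo(int64_t(1) << 20));
  assert(!isPowerOfTwo(0) && !isPowerOfTwo(48) && !isPowerOfTwo(1048577));
  return 0;
}
```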