[TritonIR] simplify Load/StoreOps when mask is true/false (#79)

* [TritonIR] fix Load/Store/CopyAsyncOp's parsers

* [TritonIR] simplify Load/StoreOps when mask is true/false

* [TEST] adds tests to check load/store simplification
This commit is contained in:
Shintaro Iwasaki
2022-08-24 12:55:49 -07:00
committed by GitHub
parent 1b513c9866
commit 84aa7d025a
6 changed files with 269 additions and 36 deletions

View File

@@ -10,6 +10,37 @@
using namespace mlir;
using namespace mlir::triton::gpu;
// Utility
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
static Type getPointeeType(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
// Tensor of pointers
auto shape = tensorType.getShape();
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
Type pointeeType = ptrType.getPointeeType();
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
// scalar pointer
Type pointeeType = ptrType.getPointeeType();
return pointeeType;
}
return Type();
}
} // namespace triton
} // namespace mlir
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
@@ -260,6 +291,44 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}
//===----------------------------------------------------------------------===//
// CopyAsyncOp
//===----------------------------------------------------------------------===//
ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type resultTypes[1], ptrType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() ||
parser.parseCustomTypeWithFallback(resultTypes[0]))
return failure();
result.addTypes(resultTypes);
SmallVector<Type> operandTypes;
operandTypes.push_back(ptrType); // ptr
if (allOperands.size() >= 2)
operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask
if (allOperands.size() >= 3)
operandTypes.push_back(triton::getPointeeType(ptrType)); // other
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
printer << " ";
printer << copyAsyncOp.getOperation()->getOperands();
printer.printOptionalAttrDict(copyAsyncOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(copyAsyncOp.ptr().getType());
printer << " -> ";
printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
@@ -282,8 +351,7 @@ public:
os << "slice";
return AliasResult::FinalAlias;
} */
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
return OpAsmDialectInterface::getAlias(attr, os);
}
};
@@ -299,36 +367,6 @@ void TritonGPUDialect::initialize() {
addInterfaces<TritonGPUOpAsmInterface>();
}
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
static Type getPointeeType(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
// Tensor of pointers
auto shape = tensorType.getShape();
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
Type pointeeType = ptrType.getPointeeType();
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
// scalar pointer
Type pointeeType = ptrType.getPointeeType();
return pointeeType;
}
return Type();
}
} // namespace triton
} // namespace mlir
//===----------------------------------------------------------------------===//
// Verification
//===----------------------------------------------------------------------===//
@@ -352,4 +390,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}
}