[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions

View File

@@ -15,7 +15,7 @@ static Type getI1SameShape(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
return i1Type;
}
static Type getI32SameShape(Type type) {
@@ -23,7 +23,7 @@ static Type getI32SameShape(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type,
tensorType.getEncoding());
return Type();
return i32Type;
}
static Type getPointerTypeFromTensor(Type type) {

View File

@@ -194,7 +194,7 @@ public:
patterns.add<CombineDotAddFRevPattern>(context);
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
patterns.add<CombineGEPPattern>(context);
patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())

View File

@@ -30,12 +30,12 @@ def CombineDotAddFRevPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c)]>;
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
def CombineGEPPattern : Pat<
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
def CombineAddPtrPattern : Pat<
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;

View File

@@ -203,7 +203,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
targetType.getEncoding());
};
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
isa<triton::SplatOp, triton::GEPOp>(op);
isa<triton::SplatOp, triton::AddPtrOp>(op);
if (hasSameTypes) {
// replace argument types
for (auto arg : llvm::enumerate(op->getOperands())) {
@@ -440,4 +440,4 @@ public:
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
return std::make_unique<TritonGPUCombineOpsPass>();
}
}

View File

@@ -70,7 +70,7 @@ private:
if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
// TODO: fill this
}
if (auto gepOp = llvm::dyn_cast<triton::GEPOp>(op)) {
if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
// TODO: fill this
}
// Triton builtin Ops