[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -203,7 +203,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
|
||||
targetType.getEncoding());
|
||||
};
|
||||
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
|
||||
isa<triton::SplatOp, triton::GEPOp>(op);
|
||||
isa<triton::SplatOp, triton::AddPtrOp>(op);
|
||||
if (hasSameTypes) {
|
||||
// replace argument types
|
||||
for (auto arg : llvm::enumerate(op->getOperands())) {
|
||||
@@ -440,4 +440,4 @@ public:
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user