[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -126,7 +126,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
|
||||
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
|
@@ -701,7 +701,7 @@ struct StoreOpConversion
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
@@ -1173,12 +1173,13 @@ struct GetProgramIdOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
|
||||
struct AddPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1298,7 +1299,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
|
||||
patterns.add<GEPOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
||||
|
@@ -323,7 +323,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
|
||||
typeConverter, context);
|
||||
|
@@ -15,7 +15,7 @@ static Type getI1SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
@@ -23,7 +23,7 @@ static Type getI32SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
static Type getPointerTypeFromTensor(Type type) {
|
||||
|
@@ -194,7 +194,7 @@ public:
|
||||
patterns.add<CombineDotAddFRevPattern>(context);
|
||||
// %}
|
||||
patterns.add<CombineSelectMaskedLoadPattern>(context);
|
||||
patterns.add<CombineGEPPattern>(context);
|
||||
patterns.add<CombineAddPtrPattern>(context);
|
||||
patterns.add<CombineBroadcastConstantPattern>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
|
@@ -30,12 +30,12 @@ def CombineDotAddFRevPattern : Pat<
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
|
||||
|
||||
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
|
||||
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
|
||||
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
||||
// (ref: ArithmeticCanonicalization.td)
|
||||
def CombineGEPPattern : Pat<
|
||||
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
||||
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||
def CombineAddPtrPattern : Pat<
|
||||
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
|
||||
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||
|
||||
// broadcast(cst) => cst
|
||||
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||
|
@@ -203,7 +203,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
|
||||
targetType.getEncoding());
|
||||
};
|
||||
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
|
||||
isa<triton::SplatOp, triton::GEPOp>(op);
|
||||
isa<triton::SplatOp, triton::AddPtrOp>(op);
|
||||
if (hasSameTypes) {
|
||||
// replace argument types
|
||||
for (auto arg : llvm::enumerate(op->getOperands())) {
|
||||
@@ -440,4 +440,4 @@ public:
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
}
|
||||
|
@@ -70,7 +70,7 @@ private:
|
||||
if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
|
||||
// TODO: fill this
|
||||
}
|
||||
if (auto gepOp = llvm::dyn_cast<triton::GEPOp>(op)) {
|
||||
if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
|
||||
// TODO: fill this
|
||||
}
|
||||
// Triton builtin Ops
|
||||
|
Reference in New Issue
Block a user