[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -701,7 +701,7 @@ struct StoreOpConversion
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
@@ -1173,12 +1173,13 @@ struct GetProgramIdOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
|
||||
struct AddPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1298,7 +1299,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
|
||||
patterns.add<GEPOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
||||
|
Reference in New Issue
Block a user