[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)
Most notably, this PR: - changes the traits (and assembly format) of addptr so it can handle offsets that have arbitrary integer width. - adds support for `cat`
This commit is contained in:
@@ -1707,6 +1707,43 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
|
||||
/// ====================== reduce codegen end ==========================
|
||||
|
||||
/// ====================== cat codegen begin ==========================
|
||||
|
||||
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
using OpAdaptor = typename CatOp::Adaptor;
|
||||
|
||||
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
// unpack input values
|
||||
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
|
||||
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
|
||||
// concatenate (and potentially reorder) values
|
||||
SmallVector<Value> retVals;
|
||||
for(Value v: lhsVals)
|
||||
retVals.push_back(v);
|
||||
for(Value v: rhsVals)
|
||||
retVals.push_back(v);
|
||||
// pack and replace
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ====================== cat codegen end ==========================
|
||||
|
||||
template <typename SourceOp>
|
||||
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
@@ -4537,6 +4574,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<TransOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user