[FRONTEND] Made more tests pass (#805)
This commit is contained in:
@@ -1400,6 +1400,8 @@ struct ExtractSliceOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: rewrite Ternary/Binary/Unary as Elementwise
|
||||
|
||||
// A CRTP style of base class.
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
class BinaryOpConversionBase
|
||||
@@ -1470,6 +1472,77 @@ struct BinaryOpConversion
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Ternary
|
||||
//
|
||||
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
class TernaryOpConversionBase
|
||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
|
||||
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
|
||||
if (!resultTy)
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
|
||||
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
|
||||
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||
auto lhss =
|
||||
this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
|
||||
auto rhss =
|
||||
this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
|
||||
auto thss =
|
||||
this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
|
||||
rhss[i], thss[i], loc);
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct TernaryOpConversion
|
||||
: public TernaryOpConversionBase<SourceOp, DestOp,
|
||||
TernaryOpConversion<SourceOp, DestOp>> {
|
||||
|
||||
explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: TernaryOpConversionBase<SourceOp, DestOp,
|
||||
TernaryOpConversion<SourceOp, DestOp>>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
// An interface to support variant DestOp builder.
|
||||
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, Value lhs, Value rhs, Value th,
|
||||
Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Unary
|
||||
//
|
||||
@@ -3590,9 +3663,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
|
||||
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
|
||||
#undef POPULATE_TERNARY_OP
|
||||
|
||||
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
|
||||
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
||||
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
|
||||
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
||||
@@ -3605,24 +3683,31 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
|
||||
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
|
||||
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
|
||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
|
||||
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
|
||||
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
|
||||
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
|
||||
#undef POPULATE_BINARY_OP
|
||||
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
||||
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
||||
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
|
||||
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
|
||||
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
|
||||
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
|
||||
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
|
||||
#undef POPULATE_CAST_OP
|
||||
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
||||
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
||||
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
|
||||
POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp)
|
||||
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
|
||||
POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
|
||||
POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
|
||||
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
|
||||
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
|
||||
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
|
||||
#undef POPULATE_UNARY_OP
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
|
@@ -351,6 +351,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
|
@@ -150,9 +150,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
TensorType ptrType = ptr.getType().cast<TensorType>();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
ptrType.getElementType().cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
state.addOperands(ptr);
|
||||
|
Reference in New Issue
Block a user