[FRONTEND] Made more tests pass (#805)

This commit is contained in:
Philippe Tillet
2022-10-26 17:47:33 -07:00
committed by GitHub
parent bb7008651a
commit 3e6cc6d66c
9 changed files with 303 additions and 166 deletions

View File

@@ -1400,6 +1400,8 @@ struct ExtractSliceOpConversion
}
};
// TODO: rewrite Ternary/Binary/Unary as Elementwise
// A CRTP-style base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class BinaryOpConversionBase
@@ -1470,6 +1472,77 @@ struct BinaryOpConversion
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
};
//
// Ternary
//
template <typename SourceOp, typename DestOp, typename ConcreteT>
class TernaryOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto lhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
auto thss =
this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
rhss[i], thss[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
/// Default three-operand lowering: every element triple is turned into a
/// single `DestOp` built from the result element type and the three values.
/// Plugs itself into `TernaryOpConversionBase` as the CRTP `ConcreteT`.
template <typename SourceOp, typename DestOp>
struct TernaryOpConversion
    : public TernaryOpConversionBase<SourceOp, DestOp,
                                     TernaryOpConversion<SourceOp, DestOp>> {
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
                               PatternBenefit benefit = 1)
      : TernaryOpConversionBase<SourceOp, DestOp, TernaryOpConversion>(
            typeConverter, benefit) {}

  // Builder hook called by the base class for each element; a different
  // derived pattern can supply its own version to build DestOp differently.
  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
                      Type elemTy, Value lhs, Value rhs, Value th,
                      Location loc) const {
    DestOp lowered = rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
    return lowered;
  }
};
//
// Unary
//
@@ -3590,9 +3663,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
@@ -3605,24 +3683,31 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
#undef POPULATE_BINARY_OP
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
#undef POPULATE_CAST_OP
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp)
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,

View File

@@ -351,6 +351,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
MLIRContext *context = patterns.getContext();
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,

View File

@@ -150,9 +150,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
TensorType ptrType = ptr.getType().cast<TensorType>();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
ptrType.getElementType().cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
Type resultType = RankedTensorType::get(shape, elementType);
state.addOperands(ptr);