[FRONTEND] Enhanced support for binary operators (#801)

Disabled the modulo test (due to a change in the behavior of `frem` in nvptx
between llvm-11 and llvm-14) and the bfloat16 tests (these will require some
work to emulate in software, similar to how it's done in `master`).
This commit is contained in:
Philippe Tillet
2022-10-24 19:47:01 -07:00
committed by GitHub
parent fcb228d1d4
commit a2cbe7af91
4 changed files with 1696 additions and 18 deletions

View File

@@ -842,9 +842,8 @@ struct LoadOpConversion
bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (valueElemTy.isa<IntegerType>() &&
matchPattern(op.other(), m_Constant(&constAttr)) &&
constAttr.isSplat()) {
if (other && valueElemTy.isa<IntegerType>() &&
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
@@ -1454,6 +1453,85 @@ struct BinaryOpConversion
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
};
//
// Unary
//
/// CRTP base for lowering a one-operand op on ranked tensors, one element at
/// a time.
///
/// `ConcreteT` must provide:
///   - `getSrc(adaptor)`: the already-converted source value (an LLVM struct
///     that is unpacked with `getElementsFromStruct`);
///   - `createDestOp(op, rewriter, elemTy, src, loc)`: the result for a single
///     unpacked element.
///
/// The per-element results are packed back into a literal LLVM struct, which
/// replaces the original op.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class UnaryOpConversionBase : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit UnaryOpConversionBase(LLVMTypeConverter &typeConverter,
                                 PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  /// Returns failure() when the result is not a ranked tensor; otherwise
  /// rewrites `op` into a struct of per-element results and returns success().
  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
    auto tensorTy = op.getType().template dyn_cast<RankedTensorType>();
    if (!tensorTy)
      return failure();

    // Only blocked encodings are expected on the result here.
    auto blockedLayout =
        tensorTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
    assert(blockedLayout && "Unexpected resultLayout in UnaryOpConversion");

    Location loc = op->getLoc();
    const unsigned numElems =
        blockedLayout.getElemsPerThread(tensorTy.getShape());

    // The converted result is a literal struct with `numElems` fields, all of
    // the converted element type.
    Type elemTy =
        this->getTypeConverter()->convertType(tensorTy.getElementType());
    SmallVector<Type> fieldTys(numElems, elemTy);
    Type packedTy =
        LLVM::LLVMStructType::getLiteral(this->getContext(), fieldTys);

    // Unpack the source struct and let the concrete pattern convert each
    // element in turn.
    auto *self = static_cast<const ConcreteT *>(this);
    auto inputs =
        this->getElementsFromStruct(loc, self->getSrc(adaptor), rewriter);
    SmallVector<Value> outputs(numElems);
    for (unsigned idx = 0; idx != numElems; ++idx)
      outputs[idx] = self->createDestOp(op, rewriter, elemTy, inputs[idx], loc);

    Value packed = getStructFromElements(loc, outputs, rewriter, packedTy);
    rewriter.replaceOp(op, packed);
    return success();
  }
};
/// Default unary lowering: each element of `SourceOp` is rewritten into a
/// `DestOp` built from the converted element type and the single source value.
///
/// The per-tensor driver lives in `UnaryOpConversionBase`; this struct only
/// supplies the two hooks that base calls (`createDestOp` and `getSrc`).
template <typename SourceOp, typename DestOp>
struct UnaryOpConversion
    : public UnaryOpConversionBase<SourceOp, DestOp,
                                   UnaryOpConversion<SourceOp, DestOp>> {
  explicit UnaryOpConversion(LLVMTypeConverter &typeConverter,
                             PatternBenefit benefit = 1)
      : UnaryOpConversionBase<SourceOp, DestOp,
                              UnaryOpConversion<SourceOp, DestOp>>(
            typeConverter, benefit) {}

  using OpAdaptor = typename SourceOp::Adaptor;

  // An interface to support variant DestOp builder.
  // Builds one `DestOp` at `loc` with result type `elemTy` and operand `src`.
  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
                      Type elemTy, Value src, Location loc) const {
    return rewriter.create<DestOp>(loc, elemTy, src);
  }

  // Get the source operand of the op.
  // Aborts via report_fatal_error unless the adaptor holds exactly one
  // operand.
  Value getSrc(OpAdaptor adaptor) const {
    auto operands = adaptor.getOperands();
    // front() on an empty range is undefined, so reject that case explicitly
    // instead of relying on the "more than one" check alone.
    if (operands.empty())
      llvm::report_fatal_error("unary operator has no operand");
    if (operands.size() > 1)
      llvm::report_fatal_error("unary operator has more than one operand");
    return operands.front();
  }
};
//
// comparisons
//
struct CmpIOpConversion
: public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
CmpIOpConversion> {
@@ -3109,6 +3187,10 @@ public:
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store bfloat16 as int16: any BFloat16Type is converted to a
// 16-bit IntegerType created in the same context.
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
Type convertTritonPointerType(triton::PointerType type) {
@@ -3367,25 +3449,44 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulIOp, LLVM::MulOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::AddFOp, LLVM::FAddOp)
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
POPULATE_BINARY_OP(arith::MulFOp, LLVM::FMulOp)
POPULATE_BINARY_OP(arith::DivFOp, LLVM::FDivOp) // /
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
#undef POPULATE_BINARY_OP
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
#undef POPULATE_CAST_OP
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);

View File

@@ -115,7 +115,9 @@ void populateArithmeticPatternsAndLegality(
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}
// this shouldn't exist if mlir's SelectOp checked encodings properly

View File

@@ -642,8 +642,31 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)
.def("create_int_cast",
     [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType,
        bool isSigned) -> mlir::Value {
       auto loc = self.getUnknownLoc();
       mlir::Type srcType = src.getType();

       // Widths are compared on element types: when the destination is a
       // ranked tensor, look through both tensor types; otherwise use the
       // types as given.
       const bool dstIsTensor = dstType.isa<mlir::RankedTensorType>();
       mlir::Type dstEltType =
           dstIsTensor
               ? dstType.cast<mlir::RankedTensorType>().getElementType()
               : dstType;
       mlir::Type srcEltType =
           dstIsTensor
               ? srcType.cast<mlir::RankedTensorType>().getElementType()
               : srcType;

       const unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
       const unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();

       // Same width: bitcast. Narrowing: truncate. Widening: sign- or
       // zero-extend depending on `isSigned`.
       if (srcWidth == dstWidth)
         return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
       if (srcWidth > dstWidth)
         return self.create<mlir::arith::TruncIOp>(loc, dstType, src);
       if (isSigned)
         return self.create<mlir::arith::ExtSIOp>(loc, dstType, src);
       return self.create<mlir::arith::ExtUIOp>(loc, dstType, src);
     })
.def("create_to_index",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();

1552
python/tests/test_core.py Normal file

File diff suppressed because it is too large Load Diff