[FRONTEND] Enhanced support for binary operators (#801)
Disabled modulo test (due to change in behavior for `frem` in nvptx between llvm-11 and llvm-14) and bfloat16 (will require some work to emulate in software similar to how it's done in `master`)
This commit is contained in:
@@ -842,9 +842,8 @@ struct LoadOpConversion
|
||||
bool otherIsSplatConstInt = false;
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(op.other(), m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
if (other && valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
@@ -1454,6 +1453,85 @@ struct BinaryOpConversion
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Unary
|
||||
//
|
||||
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
class UnaryOpConversionBase : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit UnaryOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
|
||||
|
||||
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
|
||||
if (!resultTy)
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
assert(resultLayout && "Unexpected resultLayout in UnaryOpConversion");
|
||||
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
|
||||
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||
auto srcs = this->getElementsFromStruct(loc, concreteThis->getSrc(adaptor),
|
||||
rewriter);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] =
|
||||
concreteThis->createDestOp(op, rewriter, elemTy, srcs[i], loc);
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct UnaryOpConversion
|
||||
: public UnaryOpConversionBase<SourceOp, DestOp,
|
||||
UnaryOpConversion<SourceOp, DestOp>> {
|
||||
|
||||
explicit UnaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: UnaryOpConversionBase<SourceOp, DestOp,
|
||||
UnaryOpConversion<SourceOp, DestOp>>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
// An interface to support variant DestOp builder.
|
||||
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, Value src, Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, src);
|
||||
}
|
||||
|
||||
// Get the source operand of the op.
|
||||
Value getSrc(OpAdaptor adaptor) const {
|
||||
auto operands = adaptor.getOperands();
|
||||
if (operands.size() > 1)
|
||||
llvm::report_fatal_error("unary operator has more than one operand");
|
||||
return operands.front();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// comparisons
|
||||
//
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion> {
|
||||
@@ -3109,6 +3187,10 @@ public:
|
||||
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
|
||||
return convertTritonTensorType(type);
|
||||
});
|
||||
// internally store bfloat16 as int16
|
||||
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
|
||||
return IntegerType::get(type.getContext(), 16);
|
||||
});
|
||||
}
|
||||
|
||||
Type convertTritonPointerType(triton::PointerType type) {
|
||||
@@ -3367,25 +3449,44 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::MulIOp, LLVM::MulOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
|
||||
benefit);
|
||||
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
|
||||
patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
|
||||
benefit);
|
||||
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
||||
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
|
||||
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
||||
POPULATE_BINARY_OP(arith::AddFOp, LLVM::FAddOp)
|
||||
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
|
||||
POPULATE_BINARY_OP(arith::MulFOp, LLVM::FMulOp)
|
||||
POPULATE_BINARY_OP(arith::DivFOp, LLVM::FDivOp) // /
|
||||
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
|
||||
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
|
||||
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
|
||||
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
|
||||
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
|
||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||
#undef POPULATE_BINARY_OP
|
||||
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
#define POPULATE_CAST_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_CAST_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||
POPULATE_CAST_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
||||
POPULATE_CAST_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||
POPULATE_CAST_OP(arith::ExtUIOp, LLVM::ZExtOp)
|
||||
POPULATE_CAST_OP(arith::FPToUIOp, LLVM::FPToUIOp)
|
||||
POPULATE_CAST_OP(arith::FPToSIOp, LLVM::FPToSIOp)
|
||||
POPULATE_CAST_OP(arith::UIToFPOp, LLVM::UIToFPOp)
|
||||
POPULATE_CAST_OP(arith::SIToFPOp, LLVM::SIToFPOp)
|
||||
POPULATE_CAST_OP(arith::ExtFOp, LLVM::FPExtOp)
|
||||
#undef POPULATE_CAST_OP
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
|
@@ -115,7 +115,9 @@ void populateArithmeticPatternsAndLegality(
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||
GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
|
||||
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
||||
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
||||
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
||||
|
@@ -642,8 +642,31 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||
// .def("create_downcast", &ir::builder::create_downcast)
|
||||
.def("create_int_cast",
|
||||
[](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType,
|
||||
bool isSigned) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
// get element type if necessary
|
||||
mlir::Type srcType = src.getType();
|
||||
mlir::Type srcEltType = srcType;
|
||||
mlir::Type dstEltType = dstType;
|
||||
if (dstType.isa<mlir::RankedTensorType>()) {
|
||||
dstEltType =
|
||||
dstType.cast<mlir::RankedTensorType>().getElementType();
|
||||
srcEltType =
|
||||
srcType.cast<mlir::RankedTensorType>().getElementType();
|
||||
}
|
||||
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
|
||||
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
|
||||
if (srcWidth == dstWidth)
|
||||
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
|
||||
else if (srcWidth > dstWidth)
|
||||
return self.create<mlir::arith::TruncIOp>(loc, dstType, src);
|
||||
else if (isSigned)
|
||||
return self.create<mlir::arith::ExtSIOp>(loc, dstType, src);
|
||||
else
|
||||
return self.create<mlir::arith::ExtUIOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_to_index",
|
||||
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
1552
python/tests/test_core.py
Normal file
1552
python/tests/test_core.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user