[Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
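In short: this commit folds the Unary/Binary/Ternary lowerings into a single ElementwiseOpConversion pattern, adds lowerings for math ops (log, cos, sin, sqrt, exp), fdiv and ext_elemwise, and teaches the compiler to link external bitcode libraries (libdevice) so kernels can call tl.libdevice functions. A minimal usage sketch based on the tests added below; the libdevice path is an assumption that depends on the local CUDA install:

import torch
import triton
import triton.language as tl

@triton.jit
def sin_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # calls into libdevice's sin via the new ext_elemwise lowering
    tl.store(y_ptr + offs, tl.libdevice.sin(x))

x = torch.randn(128, device='cuda')
y = torch.empty_like(x)
sin_kernel[(1,)](x, y, BLOCK=128,
                 extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})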
@@ -1,6 +1,7 @@
 #ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
 #define TRITON_TARGET_LLVMIRTRANSLATION_H
 #include <memory>
+#include <vector>
 
 namespace llvm {
 class Module;
@@ -14,6 +15,11 @@ class ModuleOp;
 namespace mlir {
 namespace triton {
 
+// add external dependent libs
+void addExternalLibs(mlir::ModuleOp &module,
+                     const std::vector<std::string> &names,
+                     const std::vector<std::string> &paths);
+
 // Translate TritonGPU dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
@@ -1792,17 +1792,15 @@ struct ExtractSliceOpConversion
   }
 };
 
-// TODO: rewrite Ternary/Binary/Unary as Elementwise
-
 // A CRTP style of base class.
 template <typename SourceOp, typename DestOp, typename ConcreteT>
-class BinaryOpConversionBase
+class ElementwiseOpConversionBase
     : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
 
-  explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
-                                  PatternBenefit benefit = 1)
+  explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
+                                       PatternBenefit benefit = 1)
       : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
 
   LogicalResult
@@ -1817,7 +1815,8 @@ public:
     auto resultLayout =
        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
     auto resultShape = resultTy.getShape();
-    assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
+    assert(resultLayout &&
+           "Unexpected resultLayout in ElementwiseOpConversionBase");
     unsigned elems = resultLayout.getElemsPerThread(resultShape);
     Type elemTy =
         this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1825,43 +1824,54 @@ public:
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
 
     auto *concreteThis = static_cast<const ConcreteT *>(this);
-    auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
-                                            rewriter);
-    auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
-                                            rewriter);
+    auto operands = getOperands(rewriter, adaptor, elems, loc);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
-                                                 rhss[i], loc);
+      resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
+                                                 operands[i], loc);
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
     return success();
   }
+
+protected:
+  SmallVector<SmallVector<Value>>
+  getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
+              const unsigned elems, Location loc) const {
+    SmallVector<SmallVector<Value>> operands(elems);
+    for (auto operand : adaptor.getOperands()) {
+      auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
+      for (int i = 0; i < elems; ++i) {
+        operands[i].push_back(sub_operands[i]);
+      }
+    }
+    return operands;
+  }
 };
 
 template <typename SourceOp, typename DestOp>
-struct BinaryOpConversion
-    : public BinaryOpConversionBase<SourceOp, DestOp,
-                                    BinaryOpConversion<SourceOp, DestOp>> {
+struct ElementwiseOpConversion
+    : public ElementwiseOpConversionBase<
+          SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
+  using Base =
+      ElementwiseOpConversionBase<SourceOp, DestOp,
+                                  ElementwiseOpConversion<SourceOp, DestOp>>;
+  using Base::Base;
+  using OpAdaptor = typename Base::OpAdaptor;
 
-  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
-                              PatternBenefit benefit = 1)
-      : BinaryOpConversionBase<SourceOp, DestOp,
-                               BinaryOpConversion<SourceOp, DestOp>>(
+  explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
+                                   PatternBenefit benefit = 1)
+      : ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
             typeConverter, benefit) {}
 
-  using OpAdaptor = typename SourceOp::Adaptor;
   // An interface to support variant DestOp builder.
-  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
-                      Type elemTy, Value lhs, Value rhs, Location loc) const {
-    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
+  DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter, Type elemTy,
+                      ValueRange operands, Location loc) const {
+    return rewriter.create<DestOp>(loc, elemTy, operands,
+                                   adaptor.getAttributes().getValue());
   }
-
-  // Get the left operand of the op.
-  Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
-  // Get the right operand of the op.
-  Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
 };
 
 //
@@ -2015,25 +2025,22 @@ struct UnaryOpConversion
 //
 
 struct CmpIOpConversion
-    : public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
-                                    CmpIOpConversion> {
-  explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
-                            PatternBenefit benefit = 1)
-      : BinaryOpConversionBase(typeConverter, benefit) {}
+    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
+                                         CmpIOpConversion> {
+  using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
+                                           CmpIOpConversion>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
 
   // An interface to support variant DestOp builder.
-  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
+  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter, Type elemTy,
-                            Value lhs, Value rhs, Location loc) const {
+                            ValueRange operands, Location loc) const {
     return rewriter.create<LLVM::ICmpOp>(
-        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
+        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), operands[0],
+        operands[1]);
   }
 
-  // Get the left operand of the op.
-  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
-  // Get the right operand of the op.
-  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
-
   static LLVM::ICmpPredicate
   ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
     switch (predicate) {
@@ -2059,25 +2066,22 @@ struct CmpIOpConversion
 };
 
 struct CmpFOpConversion
-    : public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
-                                    CmpFOpConversion> {
-  explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
-                            PatternBenefit benefit = 1)
-      : BinaryOpConversionBase(typeConverter, benefit) {}
+    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
+                                         CmpFOpConversion> {
+  using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
+                                           CmpFOpConversion>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
 
   // An interface to support variant DestOp builder.
-  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
+  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter, Type elemTy,
-                            Value lhs, Value rhs, Location loc) const {
+                            ValueRange operands, Location loc) const {
     return rewriter.create<LLVM::FCmpOp>(
-        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
+        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
+        operands[1]);
   }
 
-  // Get the left operand of the op.
-  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
-  // Get the right operand of the op.
-  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
-
   static LLVM::FCmpPredicate
   ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
     switch (predicate) {
@@ -4081,6 +4085,90 @@ struct InsertSliceAsyncOpConversion
   }
 };
 
+struct ExtElemwiseOpConversion
+    : public ElementwiseOpConversionBase<
+          triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
+                                  ExtElemwiseOpConversion>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Type elemTy,
+                     ValueRange operands, Location loc) const {
+    StringRef funcName = op.symbol();
+    if (funcName.empty())
+      llvm::errs() << "ExtElemwiseOpConversion";
+
+    Type funcType = getFunctionType(elemTy, operands);
+    LLVM::LLVMFuncOp funcOp =
+        appendOrGetFuncOp(rewriter, op, funcName, funcType);
+    return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
+  }
+
+private:
+  Type getFunctionType(Type resultType, ValueRange operands) const {
+    SmallVector<Type> operandTypes(operands.getTypes());
+    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
+                                     triton::ExtElemwiseOp op,
+                                     StringRef funcName, Type funcType) const {
+    using LLVM::LLVMFuncOp;
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
+    if (funcOp)
+      return cast<LLVMFuncOp>(*funcOp);
+
+    mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
+    auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+    ret.getOperation()->setAttr(
+        "libname", StringAttr::get(op->getContext(), op.libname()));
+    ret.getOperation()->setAttr(
+        "libpath", StringAttr::get(op->getContext(), op.libpath()));
+    return ret;
+  }
+};
+
+struct FDivOpConversion
+    : ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
+                                  FDivOpConversion> {
+  using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
+                                           LLVM::InlineAsmOp, FDivOpConversion>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Type elemTy,
+                     ValueRange operands, Location loc) const {
+
+    PTXBuilder ptxBuilder;
+    auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
+    unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
+    if (32 == bitwidth) {
+      fdiv.o("full").o("f32");
+      auto res = ptxBuilder.newOperand("=r");
+      auto lhs = ptxBuilder.newOperand(operands[0], "r");
+      auto rhs = ptxBuilder.newOperand(operands[1], "r");
+      fdiv(res, lhs, rhs);
+    } else if (64 == bitwidth) {
+      fdiv.o("rn").o("f64");
+      auto res = ptxBuilder.newOperand("=l");
+      auto lhs = ptxBuilder.newOperand(operands[0], "l");
+      auto rhs = ptxBuilder.newOperand(operands[1], "l");
+      fdiv(res, lhs, rhs);
+    } else {
+      assert(0 && bitwidth && "not supported");
+    }
+
+    Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
+    return ret;
+  }
+};
+
 void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns, int numWarps,
                                   AxisInfoAnalysis &axisInfoAnalysis,
@@ -4093,12 +4181,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
 
 #define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
-  patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
   POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
 #undef POPULATE_TERNARY_OP
 
 #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
-  patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+
   POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
   POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
   POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
@@ -4122,7 +4211,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<CmpIOpConversion>(typeConverter, benefit);
   patterns.add<CmpFOpConversion>(typeConverter, benefit);
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
-  patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
   POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
   POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
   POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
@@ -4135,8 +4224,17 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
   POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
   POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
+  POPULATE_UNARY_OP(math::LogOp, math::LogOp)
+  POPULATE_UNARY_OP(math::CosOp, math::CosOp)
+  POPULATE_UNARY_OP(math::SinOp, math::SinOp)
+  POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
+  POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
 #undef POPULATE_UNARY_OP
 
+  patterns.add<FDivOpConversion>(typeConverter, benefit);
+
+  patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
+
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
   patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
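With Unary/Binary/Ternary conversions unified under ElementwiseOpConversion above, wiring up one more 1:1 elementwise lowering reduces to a single registration line placed before the corresponding #undef; a hypothetical sketch (math::FloorOp is not part of this commit):

  // hypothetical example only: reuse the same macro for another math op
  POPULATE_UNARY_OP(math::FloorOp, math::FloorOp)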
@@ -16,6 +16,9 @@
 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
 #include "triton/tools/sys/getenv.hpp"
 #include "llvm/IR/Constants.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Linker/Linker.h"
+#include "llvm/Support/SourceMgr.h"
 
 namespace mlir {
 namespace triton {
@@ -148,13 +151,80 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
     return nullptr;
   }
 
+  std::map<std::string, std::string> extern_libs;
+  SmallVector<LLVM::LLVMFuncOp> funcs;
+  module.walk([&](LLVM::LLVMFuncOp func) {
+    if (func.isExternal())
+      funcs.push_back(func);
+  });
+
+  for (auto &func : funcs) {
+    if (func.getOperation()->hasAttr("libname")) {
+      auto name =
+          func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
+      auto path =
+          func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
+      if (name) {
+        std::string lib_name = name.str();
+        extern_libs[lib_name] = path.str();
+      }
+    }
+  }
+
+  if (module.getOperation()->hasAttr("triton_gpu.externs")) {
+    auto dict = module.getOperation()
+                    ->getAttr("triton_gpu.externs")
+                    .dyn_cast<DictionaryAttr>();
+    for (auto &attr : dict) {
+      extern_libs[attr.getName().strref().trim().str()] =
+          attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
+    }
+  }
+
   auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
+    return nullptr;
+  }
+
+  llvm::SMDiagnostic err;
+  for (auto &lib : extern_libs) {
+    auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext);
+    if (!ext_mod) {
+      llvm::errs() << "Failed to load extern lib " << lib.first;
+      return nullptr;
+    }
+    ext_mod->setTargetTriple(llvmir->getTargetTriple());
+    ext_mod->setDataLayout(llvmir->getDataLayout());
+
+    if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) {
+      llvm::errs() << "Failed to link extern lib " << lib.first;
+      return nullptr;
+    }
   }
 
   return llvmir;
 }
 
+void addExternalLibs(mlir::ModuleOp &module,
+                     const std::vector<std::string> &names,
+                     const std::vector<std::string> &paths) {
+  if (names.empty() || names.size() != paths.size())
+    return;
+
+  llvm::SmallVector<NamedAttribute, 2> attrs;
+
+  for (size_t i = 0; i < names.size(); ++i) {
+    auto name = StringAttr::get(module->getContext(), names[i]);
+    auto path = StringAttr::get(module->getContext(), paths[i]);
+    NamedAttribute attr(name, path);
+    attrs.push_back(attr);
+  }
+
+  DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
+  module.getOperation()->setAttr("triton_gpu.externs", dict);
+  return;
+}
+
 } // namespace triton
 } // namespace mlir
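The extern-lib flow implemented above: the Python side records name-to-path pairs as a triton_gpu.externs DictionaryAttr on the module (addExternalLibs); during translation, those paths, plus any libname/libpath attributes left on external LLVMFuncOps by ExtElemwiseOpConversion, are parsed with llvm::parseIRFile and merged into the kernel module via llvm::Linker::linkModules. A sketch of the corresponding Python-side call added in compiler.py further below; the path is illustrative:

# extern_libs maps a library name to an LLVM bitcode file on disk
add_external_libs(module, {"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})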
@@ -1335,6 +1335,12 @@ void init_triton_translation(py::module &m) {
         py::bytes bytes(cubin);
         return bytes;
       });
+
+  m.def("add_external_libs",
+        [](mlir::ModuleOp &op, const std::vector<std::string> &names,
+           const std::vector<std::string> &paths) {
+          ::mlir::triton::addExternalLibs(op, names, paths);
+        });
 }
 
 void init_triton(py::module &m) {
python/tests/test_elementwise.py (new file)
@@ -0,0 +1,189 @@
+import tempfile
+from inspect import Parameter, Signature
+
+import _testcapi
+import pytest
+import torch
+from torch.testing import assert_close
+
+import triton
+import triton.language as tl
+
+torch_type = {
+    "bool": torch.bool,
+    "int32": torch.int32,
+    "float32": torch.float32,
+    "float64": torch.float64
+}
+
+torch_ops = {
+    "log": "log",
+    "cos": "cos",
+    "sin": "sin",
+    "sqrt": "sqrt",
+    "abs": "abs",
+    "exp": "exp",
+    "sigmoid": "sigmoid",
+    "umulhi": None,
+    "cdiv": None,
+    "fdiv": "div",
+    "minimum": "minimum",
+    "maximum": "maximum",
+    "where": "where",
+}
+
+libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
+
+
+def get_tensor(shape, data_type, b_positive=False):
+    x = None
+    if data_type.startswith('int'):
+        x = torch.randint(2**31 - 1, shape, dtype=torch_type[data_type], device='cuda')
+    elif data_type.startswith('bool'):
+        x = torch.randint(1, shape, dtype=torch_type[data_type], device='cuda')
+    else:
+        x = torch.randn(shape, dtype=torch_type[data_type], device='cuda')
+
+    if b_positive:
+        x = torch.abs(x)
+
+    return x
+
+
+@pytest.mark.parametrize('expr, output_type, input0_type',
+                         [('log', 'float32', 'float32'),
+                          ('log', 'float64', 'float64'),
+                          ('cos', 'float32', 'float32'),
+                          ('cos', 'float64', 'float64'),
+                          ('sin', 'float32', 'float32'),
+                          ('sin', 'float64', 'float64'),
+                          ('sqrt', 'float32', 'float32'),
+                          ('sqrt', 'float64', 'float64'),
+                          ('abs', 'float32', 'float32'),
+                          ('exp', 'float32', 'float32'),
+                          ('sigmoid', 'float32', 'float32'),
+                          ])
+def test_single_input(expr, output_type, input0_type):
+    src = f"""
+def kernel(X, Y, BLOCK: tl.constexpr):
+    x = tl.load(X + tl.arange(0, BLOCK))
+    y = tl.{expr}(x)
+    tl.store(Y + tl.arange(0, BLOCK), y)
+"""
+    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
+    fp.write(src)
+    fp.flush()
+
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        pass
+    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
+    parameters = []
+    parameters.append(Parameter("X", 1))
+    parameters.append(Parameter("Y", 1))
+    parameters.append(Parameter("BLOCK", 1))
+    kernel.__signature__ = Signature(parameters=parameters)
+    kernel = triton.jit(kernel)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
+    # triton result
+    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
+    # reference result
+    y_ref = getattr(torch, torch_ops[expr])(x)
+    # compare
+    assert_close(y, y_ref)
+
+
+@pytest.mark.parametrize('expr, output_type, input0_type, input1_type',
+                         [('umulhi', 'int32', 'int32', 'int32'),
+                          ('cdiv', 'int32', 'int32', 'int32'),
+                          ('fdiv', 'float32', 'float32', 'float32'),
+                          ('minimum', 'float32', 'float32', 'float32'),
+                          ('maximum', 'float32', 'float32', 'float32'),
+                          ])
+def test_two_input(expr, output_type, input0_type, input1_type):
+    src = f"""
+def kernel(X0, X1, Y, BLOCK: tl.constexpr):
+    x0 = tl.load(X0 + tl.arange(0, BLOCK))
+    x1 = tl.load(X1 + tl.arange(0, BLOCK))
+    y = tl.{expr}(x0, x1)
+    tl.store(Y + tl.arange(0, BLOCK), y)
+"""
+    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
+    fp.write(src)
+    fp.flush()
+
+    def kernel(X0, X1, Y, BLOCK: tl.constexpr):
+        pass
+    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
+    parameters = []
+    parameters.append(Parameter("X0", 1))
+    parameters.append(Parameter("X1", 1))
+    parameters.append(Parameter("Y", 1))
+    parameters.append(Parameter("BLOCK", 1))
+    kernel.__signature__ = Signature(parameters=parameters)
+    kernel = triton.jit(kernel)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x0 = get_tensor(shape, input0_type)
+    x1 = get_tensor(shape, input1_type)
+
+    # triton result
+    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
+    kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
+    # reference result
+
+    if expr == "cdiv":
+        y_ref = (x0 + x1 - 1) // x1
+    elif expr == "umulhi":
+        y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
+    else:
+        y_ref = getattr(torch, torch_ops[expr])(x0, x1)
+    # compare
+    assert_close(y, y_ref)
+
+
+@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type',
+                         [('where', "int32", "bool", "int32", "int32"), ])
+def test_three_input(expr, output_type, input0_type, input1_type, input2_type):
+    src = f"""
+def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
+    x0 = tl.load(X0 + tl.arange(0, BLOCK))
+    x1 = tl.load(X1 + tl.arange(0, BLOCK))
+    x2 = tl.load(X2 + tl.arange(0, BLOCK))
+    y = tl.{expr}(x0, x1, x2)
+    tl.store(Y + tl.arange(0, BLOCK), y)
+"""
+    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
+    fp.write(src)
+    fp.flush()
+
+    def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
+        pass
+    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
+    parameters = []
+    parameters.append(Parameter("X0", 1))
+    parameters.append(Parameter("X1", 1))
+    parameters.append(Parameter("X2", 1))
+    parameters.append(Parameter("Y", 1))
+    parameters.append(Parameter("BLOCK", 1))
+    kernel.__signature__ = Signature(parameters=parameters)
+    kernel = triton.jit(kernel)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x0 = get_tensor(shape, input0_type)
+    x1 = get_tensor(shape, input1_type)
+    x2 = get_tensor(shape, input1_type)
+
+    # triton result
+    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
+    kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
+    # reference result
+
+    y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
+    # compare
+    assert_close(y, y_ref)
python/tests/test_ext_elemwise.py (new file)
@@ -0,0 +1,178 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+import triton
+import triton.language as tl
+
+
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
+    [4, 256, 1],
+    [4, 1024, 256],
+])
+def test_sin_no_mask(num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               block_size,
+               iter_size: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            x = tl.load(x_ptrs)
+            y = tl.libdevice.sin(x)
+            y_ptrs = y_ptr + offset
+            tl.store(y_ptrs, y)
+
+            x_ptr += iter_size
+            y_ptr += iter_size
+
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
+
+    golden_y = torch.sin(x)
+    assert_close(y, golden_y, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
+    [4, 256, 1],
+    [4, 1024, 256],
+])
+def test_fmin_no_mask(num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               block_size,
+               iter_size: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+
+            x = tl.load(x_ptrs)
+            y = tl.load(y_ptrs)
+            z = tl.libdevice.min(x, y)
+            z_ptrs = z_ptr + offset
+            tl.store(z_ptrs, z)
+
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size
+
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
+
+    golden_z = torch.minimum(x, y)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
+    [4, 256, 1],
+    [4, 1024, 256],
+])
+def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               w_ptr,
+               block_size,
+               iter_size: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+            z_ptrs = z_ptr + offset
+
+            x = tl.load(x_ptrs)
+            y = tl.load(y_ptrs)
+            z = tl.load(z_ptrs)
+
+            w = tl.libdevice.fma_rn(x, y, z)
+            w_ptrs = w_ptr + offset
+            tl.store(w_ptrs, w)
+
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size
+            w_ptr += iter_size
+
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float64)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float64)
+    z = torch.randn((block_size,), device='cuda', dtype=torch.float64)
+    w = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
+
+    golden_w = x * y + z
+    assert_close(w, golden_w, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
+                          ('int32', 'libdevice.ffs', '')])
+def test_libdevice(dtype_str, expr, lib_path):
+    src = f"""
+def kernel(X, Y, BLOCK: tl.constexpr):
+    x = tl.load(X + tl.arange(0, BLOCK))
+    y = tl.{expr}(x)
+    tl.store(Y + tl.arange(0, BLOCK), y)
+"""
+    import tempfile
+    from inspect import Parameter, Signature
+
+    import _testcapi
+
+    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
+    fp.write(src)
+    fp.flush()
+
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        pass
+    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
+    parameters = []
+    parameters.append(Parameter("X", 1))
+    parameters.append(Parameter("Y", 1))
+    parameters.append(Parameter("BLOCK", 1))
+    kernel.__signature__ = Signature(parameters=parameters)
+    kernel = triton.jit(kernel)
+
+    torch_type = {
+        "int32": torch.int32,
+        "float32": torch.float32,
+        "float64": torch.float64
+    }
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x = None
+    if dtype_str == "int32":
+        x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda")
+    else:
+        x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda")
+    if expr == 'libdevice.ffs':
+        y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda")
+        for i in range(shape[0]):
+            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
+
+    # triton result
+    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
+    # compare
+    assert_close(y, y_ref)
@@ -36,6 +36,7 @@ def str_to_ty(name):
         "bf16": triton.language.bfloat16,
         "fp32": triton.language.float32,
         "fp64": triton.language.float64,
+        "i1": triton.language.int1,
         "i8": triton.language.int8,
         "i16": triton.language.int16,
         "i32": triton.language.int32,
@@ -45,7 +46,6 @@ def str_to_ty(name):
         "u32": triton.language.uint32,
         "u64": triton.language.uint64,
         "B": triton.language.int1,
-        "i1": triton.language.int1,
     }
     return tys[name]
 
@@ -888,6 +888,13 @@ def optimize_tritongpu_ir(mod, num_stages):
     return mod
 
 
+def add_external_libs(mod, libs):
+    for name, path in libs.items():
+        if len(name) == 0 or len(path) == 0:
+            return
+    _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
+
+
 def make_llvm_ir(mod):
     return _triton.translate_triton_gpu_to_llvmir(mod)
 
@@ -986,6 +993,8 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
+    if extern_libs:
+        add_external_libs(module, extern_libs)
 
     # llvm-ir
     llvm_ir = make_llvm_ir(module)
python/triton/language/libdevice.10.bc (new executable file, binary content not shown)
@@ -226,7 +226,6 @@ def fdiv(input: tl.tensor,
         raise ValueError("both operands of fdiv must have floating point scalar type")
     input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
     ret = builder.create_fdiv(input.handle, other.handle)
-    ret.set_fdiv_ieee_rounding(ieee_rounding)
     return tl.tensor(ret, input.type)
 
 
@@ -1074,7 +1073,8 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
 
 def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
     x, y = binary_op_type_checking_impl(x, y, builder)
-    return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
+    from . import libdevice
+    return libdevice.mulhi(x, y, _builder=builder)
 
 
 def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: