diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h index 01411414b..52395b0b6 100644 --- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h +++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h @@ -1,6 +1,7 @@ #ifndef TRITON_TARGET_LLVMIRTRANSLATION_H #define TRITON_TARGET_LLVMIRTRANSLATION_H #include +#include namespace llvm { class Module; @@ -14,6 +15,11 @@ class ModuleOp; namespace mlir { namespace triton { +// add external dependent libs +void addExternalLibs(mlir::ModuleOp &module, + const std::vector &names, + const std::vector &paths); + // Translate TritonGPU dialect to LLVMIR, return null if failed. std::unique_ptr translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 50470ac88..4949c1a87 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1792,17 +1792,15 @@ struct ExtractSliceOpConversion } }; -// TODO: rewrite Ternary/Binary/Unary as Elementwise - // A CRTP style of base class. template -class BinaryOpConversionBase +class ElementwiseOpConversionBase : public ConvertTritonGPUOpToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; - explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) + explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} LogicalResult @@ -1817,7 +1815,8 @@ public: auto resultLayout = resultTy.getEncoding().template dyn_cast(); auto resultShape = resultTy.getShape(); - assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion"); + assert(resultLayout && + "Unexpected resultLayout in ElementwiseOpConversionBase"); unsigned elems = resultLayout.getElemsPerThread(resultShape); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); @@ -1825,43 +1824,54 @@ public: Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); auto *concreteThis = static_cast(this); - auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor), - rewriter); - auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor), - rewriter); + auto operands = getOperands(rewriter, adaptor, elems, loc); SmallVector resultVals(elems); for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i], - rhss[i], loc); + resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy, + operands[i], loc); } Value view = getStructFromElements(loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, view); return success(); } + +protected: + SmallVector> + getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor, + const unsigned elems, Location loc) const { + SmallVector> operands(elems); + for (auto operand : adaptor.getOperands()) { + auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + for (int i = 0; i < elems; ++i) { + operands[i].push_back(sub_operands[i]); + } + } + return operands; + } }; template -struct BinaryOpConversion - : public BinaryOpConversionBase> { +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, DestOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; - explicit BinaryOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : BinaryOpConversionBase>( + explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase( typeConverter, benefit) {} - using OpAdaptor = typename SourceOp::Adaptor; // An interface to support variant DestOp builder. - DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter, - Type elemTy, Value lhs, Value rhs, Location loc) const { - return rewriter.create(loc, elemTy, lhs, rhs); + DestOp createDestOp(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + ValueRange operands, Location loc) const { + return rewriter.create(loc, elemTy, operands, + adaptor.getAttributes().getValue()); } - - // Get the left operand of the op. - Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); } - // Get the right operand of the op. - Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); } }; // @@ -2015,25 +2025,22 @@ struct UnaryOpConversion // struct CmpIOpConversion - : public BinaryOpConversionBase { - explicit CmpIOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : BinaryOpConversionBase(typeConverter, benefit) {} + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; // An interface to support variant DestOp builder. - LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, + LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, - Value lhs, Value rhs, Location loc) const { + ValueRange operands, Location loc) const { return rewriter.create( - loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs); + loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), operands[0], + operands[1]); } - // Get the left operand of the op. - Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); } - // Get the right operand of the op. - Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); } - static LLVM::ICmpPredicate ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) { switch (predicate) { @@ -2059,25 +2066,22 @@ struct CmpIOpConversion }; struct CmpFOpConversion - : public BinaryOpConversionBase { - explicit CmpFOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : BinaryOpConversionBase(typeConverter, benefit) {} + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; // An interface to support variant DestOp builder. - LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, + LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, - Value lhs, Value rhs, Location loc) const { + ValueRange operands, Location loc) const { return rewriter.create( - loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs); + loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0], + operands[1]); } - // Get the left operand of the op. - Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); } - // Get the right operand of the op. - Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); } - static LLVM::FCmpPredicate ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) { switch (predicate) { @@ -4081,6 +4085,90 @@ struct InsertSliceAsyncOpConversion } }; +struct ExtElemwiseOpConversion + : public ElementwiseOpConversionBase< + triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + ValueRange operands, Location loc) const { + StringRef funcName = op.symbol(); + if (funcName.empty()) + llvm::errs() << "ExtElemwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands); + LLVM::LLVMFuncOp funcOp = + appendOrGetFuncOp(rewriter, op, funcName, funcType); + return rewriter.create(loc, funcOp, operands).getResult(0); + } + +private: + Type getFunctionType(Type resultType, ValueRange operands) const { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); + } + + LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, + triton::ExtElemwiseOp op, + StringRef funcName, Type funcType) const { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + mlir::OpBuilder b(op->getParentOfType()); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr( + "libname", StringAttr::get(op->getContext(), op.libname())); + ret.getOperation()->setAttr( + "libpath", StringAttr::get(op->getContext(), op.libpath())); + return ret; + } +}; + +struct FDivOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + ValueRange operands, Location loc) const { + + PTXBuilder ptxBuilder; + auto &fdiv = *ptxBuilder.create("div"); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + if (32 == bitwidth) { + fdiv.o("full").o("f32"); + auto res = ptxBuilder.newOperand("=r"); + auto lhs = ptxBuilder.newOperand(operands[0], "r"); + auto rhs = ptxBuilder.newOperand(operands[1], "r"); + fdiv(res, lhs, rhs); + } else if (64 == bitwidth) { + fdiv.o("rn").o("f64"); + auto res = ptxBuilder.newOperand("=l"); + auto lhs = ptxBuilder.newOperand(operands[0], "l"); + auto rhs = ptxBuilder.newOperand(operands[1], "l"); + fdiv(res, lhs, rhs); + } else { + assert(0 && bitwidth && "not supported"); + } + + Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false); + return ret; + } +}; + void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, @@ -4093,12 +4181,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, benefit); #define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, benefit); POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp); #undef POPULATE_TERNARY_OP #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, benefit); + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp) POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + @@ -4122,7 +4211,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, benefit); POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp) POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) @@ -4135,8 +4224,17 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) #undef POPULATE_UNARY_OP + patterns.add(typeConverter, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); patterns.add(typeConverter, allocation, smem, diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 5ed79cd81..be2e65b31 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -16,6 +16,9 @@ #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/tools/sys/getenv.hpp" #include "llvm/IR/Constants.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/SourceMgr.h" namespace mlir { namespace triton { @@ -148,13 +151,80 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, return nullptr; } + std::map extern_libs; + SmallVector funcs; + module.walk([&](LLVM::LLVMFuncOp func) { + if (func.isExternal()) + funcs.push_back(func); + }); + + for (auto &func : funcs) { + if (func.getOperation()->hasAttr("libname")) { + auto name = + func.getOperation()->getAttr("libname").dyn_cast(); + auto path = + func.getOperation()->getAttr("libpath").dyn_cast(); + if (name) { + std::string lib_name = name.str(); + extern_libs[lib_name] = path.str(); + } + } + } + + if (module.getOperation()->hasAttr("triton_gpu.externs")) { + auto dict = module.getOperation() + ->getAttr("triton_gpu.externs") + .dyn_cast(); + for (auto &attr : dict) { + extern_libs[attr.getName().strref().trim().str()] = + attr.getValue().dyn_cast().strref().trim().str(); + } + } + auto llvmir = translateLLVMToLLVMIR(llvmContext, module); if (!llvmir) { llvm::errs() << "Translate to LLVM IR failed"; + return nullptr; + } + + llvm::SMDiagnostic err; + for (auto &lib : extern_libs) { + auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext); + if (!ext_mod) { + llvm::errs() << "Failed to load extern lib " << lib.first; + return nullptr; + } + ext_mod->setTargetTriple(llvmir->getTargetTriple()); + ext_mod->setDataLayout(llvmir->getDataLayout()); + + if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) { + llvm::errs() << "Failed to link extern lib " << lib.first; + return nullptr; + } } return llvmir; } +void addExternalLibs(mlir::ModuleOp &module, + const std::vector &names, + const std::vector &paths) { + if (names.empty() || names.size() != paths.size()) + return; + + llvm::SmallVector attrs; + + for (size_t i = 0; i < names.size(); ++i) { + auto name = StringAttr::get(module->getContext(), names[i]); + auto path = StringAttr::get(module->getContext(), paths[i]); + NamedAttribute attr(name, path); + attrs.push_back(attr); + } + + DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs); + module.getOperation()->setAttr("triton_gpu.externs", dict); + return; +} + } // namespace triton } // namespace mlir diff --git a/python/src/triton.cc b/python/src/triton.cc index a15d2dda2..a3d5a357e 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1335,6 +1335,12 @@ void init_triton_translation(py::module &m) { py::bytes bytes(cubin); return bytes; }); + + m.def("add_external_libs", + [](mlir::ModuleOp &op, const std::vector &names, + const std::vector &paths) { + ::mlir::triton::addExternalLibs(op, names, paths); + }); } void init_triton(py::module &m) { diff --git a/python/tests/test_elementwise.py b/python/tests/test_elementwise.py new file mode 100644 index 000000000..f27990e74 --- /dev/null +++ b/python/tests/test_elementwise.py @@ -0,0 +1,189 @@ +import tempfile +from inspect import Parameter, Signature + +import _testcapi +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +torch_type = { + "bool": torch.bool, + "int32": torch.int32, + "float32": torch.float32, + "float64": torch.float64 +} + +torch_ops = { + "log": "log", + "cos": "cos", + "sin": "sin", + "sqrt": "sqrt", + "abs": "abs", + "exp": "exp", + "sigmoid": "sigmoid", + "umulhi": None, + "cdiv": None, + "fdiv": "div", + "minimum": "minimum", + "maximum": "maximum", + "where": "where", +} + +libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc' + + +def get_tensor(shape, data_type, b_positive=False): + x = None + if data_type.startswith('int'): + x = torch.randint(2**31 - 1, shape, dtype=torch_type[data_type], device='cuda') + elif data_type.startswith('bool'): + x = torch.randint(1, shape, dtype=torch_type[data_type], device='cuda') + else: + x = torch.randn(shape, dtype=torch_type[data_type], device='cuda') + + if b_positive: + x = torch.abs(x) + + return x + + +@pytest.mark.parametrize('expr, output_type, input0_type', + [('log', 'float32', 'float32'), + ('log', 'float64', 'float64'), + ('cos', 'float32', 'float32'), + ('cos', 'float64', 'float64'), + ('sin', 'float32', 'float32'), + ('sin', 'float64', 'float64'), + ('sqrt', 'float32', 'float32'), + ('sqrt', 'float64', 'float64'), + ('abs', 'float32', 'float32'), + ('exp', 'float32', 'float32'), + ('sigmoid', 'float32', 'float32'), + ]) +def test_single_input(expr, output_type, input0_type): + src = f""" +def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.{expr}(x) + tl.store(Y + tl.arange(0, BLOCK), y) +""" + fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py") + fp.write(src) + fp.flush() + + def kernel(X, Y, BLOCK: tl.constexpr): + pass + kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1) + parameters = [] + parameters.append(Parameter("X", 1)) + parameters.append(Parameter("Y", 1)) + parameters.append(Parameter("BLOCK", 1)) + kernel.__signature__ = Signature(parameters=parameters) + kernel = triton.jit(kernel) + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt') + # triton result + y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") + kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + # reference result + y_ref = getattr(torch, torch_ops[expr])(x) + # compare + assert_close(y, y_ref) + + +@pytest.mark.parametrize('expr, output_type, input0_type, input1_type', + [('umulhi', 'int32', 'int32', 'int32'), + ('cdiv', 'int32', 'int32', 'int32'), + ('fdiv', 'float32', 'float32', 'float32'), + ('minimum', 'float32', 'float32', 'float32'), + ('maximum', 'float32', 'float32', 'float32'), + ]) +def test_two_input(expr, output_type, input0_type, input1_type): + src = f""" +def kernel(X0, X1, Y, BLOCK: tl.constexpr): + x0 = tl.load(X0 + tl.arange(0, BLOCK)) + x1 = tl.load(X1 + tl.arange(0, BLOCK)) + y = tl.{expr}(x0, x1) + tl.store(Y + tl.arange(0, BLOCK), y) +""" + fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py") + fp.write(src) + fp.flush() + + def kernel(X0, X1, Y, BLOCK: tl.constexpr): + pass + kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1) + parameters = [] + parameters.append(Parameter("X0", 1)) + parameters.append(Parameter("X1", 1)) + parameters.append(Parameter("Y", 1)) + parameters.append(Parameter("BLOCK", 1)) + kernel.__signature__ = Signature(parameters=parameters) + kernel = triton.jit(kernel) + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x0 = get_tensor(shape, input0_type) + x1 = get_tensor(shape, input1_type) + + # triton result + y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") + kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + # reference result + + if expr == "cdiv": + y_ref = (x0 + x1 - 1) // x1 + elif expr == "umulhi": + y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32) + else: + y_ref = getattr(torch, torch_ops[expr])(x0, x1) + # compare + assert_close(y, y_ref) + + +@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type', + [('where', "int32", "bool", "int32", "int32"), ]) +def test_three_input(expr, output_type, input0_type, input1_type, input2_type): + src = f""" +def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr): + x0 = tl.load(X0 + tl.arange(0, BLOCK)) + x1 = tl.load(X1 + tl.arange(0, BLOCK)) + x2 = tl.load(X2 + tl.arange(0, BLOCK)) + y = tl.{expr}(x0, x1, x2) + tl.store(Y + tl.arange(0, BLOCK), y) +""" + fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py") + fp.write(src) + fp.flush() + + def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr): + pass + kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1) + parameters = [] + parameters.append(Parameter("X0", 1)) + parameters.append(Parameter("X1", 1)) + parameters.append(Parameter("X2", 1)) + parameters.append(Parameter("Y", 1)) + parameters.append(Parameter("BLOCK", 1)) + kernel.__signature__ = Signature(parameters=parameters) + kernel = triton.jit(kernel) + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x0 = get_tensor(shape, input0_type) + x1 = get_tensor(shape, input1_type) + x2 = get_tensor(shape, input1_type) + + # triton result + y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") + kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + # reference result + + y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2) + # compare + assert_close(y, y_ref) diff --git a/python/tests/test_ext_elemwise.py b/python/tests/test_ext_elemwise.py new file mode 100644 index 000000000..9e44db65e --- /dev/null +++ b/python/tests/test_ext_elemwise.py @@ -0,0 +1,178 @@ + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@pytest.mark.parametrize('num_warps, block_size, iter_size', [ + [4, 256, 1], + [4, 1024, 256], +]) +def test_sin_no_mask(num_warps, block_size, iter_size): + @triton.jit + def kernel(x_ptr, + y_ptr, + block_size, + iter_size: tl.constexpr): + pid = tl.program_id(axis=0) + for i in range(0, block_size, iter_size): + offset = pid * block_size + tl.arange(0, iter_size) + x_ptrs = x_ptr + offset + x = tl.load(x_ptrs) + y = tl.libdevice.sin(x) + y_ptrs = y_ptr + offset + tl.store(y_ptrs, y) + + x_ptr += iter_size + y_ptr += iter_size + + x = torch.randn((block_size,), device='cuda', dtype=torch.float32) + y = torch.empty((block_size,), device=x.device, dtype=x.dtype) + + grid = lambda EA: (x.shape.numel() // (block_size),) + kernel[grid](x_ptr=x, y_ptr=y, + block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps) + + golden_y = torch.sin(x) + assert_close(y, golden_y, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize('num_warps, block_size, iter_size', [ + [4, 256, 1], + [4, 1024, 256], +]) +def test_fmin_no_mask(num_warps, block_size, iter_size): + @triton.jit + def kernel(x_ptr, + y_ptr, + z_ptr, + block_size, + iter_size: tl.constexpr): + pid = tl.program_id(axis=0) + for i in range(0, block_size, iter_size): + offset = pid * block_size + tl.arange(0, iter_size) + x_ptrs = x_ptr + offset + y_ptrs = y_ptr + offset + + x = tl.load(x_ptrs) + y = tl.load(y_ptrs) + z = tl.libdevice.min(x, y) + z_ptrs = z_ptr + offset + tl.store(z_ptrs, z) + + x_ptr += iter_size + y_ptr += iter_size + z_ptr += iter_size + + x = torch.randn((block_size,), device='cuda', dtype=torch.float32) + y = torch.randn((block_size,), device='cuda', dtype=torch.float32) + z = torch.empty((block_size,), device=x.device, dtype=x.dtype) + + grid = lambda EA: (x.shape.numel() // (block_size),) + kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, + block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps) + + golden_z = torch.minimum(x, y) + assert_close(z, golden_z, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize('num_warps, block_size, iter_size', [ + [4, 256, 1], + [4, 1024, 256], +]) +def test_fmad_rn_no_mask(num_warps, block_size, iter_size): + @triton.jit + def kernel(x_ptr, + y_ptr, + z_ptr, + w_ptr, + block_size, + iter_size: tl.constexpr): + pid = tl.program_id(axis=0) + for i in range(0, block_size, iter_size): + offset = pid * block_size + tl.arange(0, iter_size) + x_ptrs = x_ptr + offset + y_ptrs = y_ptr + offset + z_ptrs = z_ptr + offset + + x = tl.load(x_ptrs) + y = tl.load(y_ptrs) + z = tl.load(z_ptrs) + + w = tl.libdevice.fma_rn(x, y, z) + w_ptrs = w_ptr + offset + tl.store(w_ptrs, w) + + x_ptr += iter_size + y_ptr += iter_size + z_ptr += iter_size + w_ptr += iter_size + + x = torch.randn((block_size,), device='cuda', dtype=torch.float64) + y = torch.randn((block_size,), device='cuda', dtype=torch.float64) + z = torch.randn((block_size,), device='cuda', dtype=torch.float64) + w = torch.empty((block_size,), device=x.device, dtype=x.dtype) + + grid = lambda EA: (x.shape.numel() // (block_size),) + kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w, + block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps) + + golden_w = x * y + z + assert_close(w, golden_w, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('int32', 'libdevice.ffs', '')]) +def test_libdevice(dtype_str, expr, lib_path): + src = f""" +def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.{expr}(x) + tl.store(Y + tl.arange(0, BLOCK), y) +""" + import tempfile + from inspect import Parameter, Signature + + import _testcapi + + fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py") + fp.write(src) + fp.flush() + + def kernel(X, Y, BLOCK: tl.constexpr): + pass + kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1) + parameters = [] + parameters.append(Parameter("X", 1)) + parameters.append(Parameter("Y", 1)) + parameters.append(Parameter("BLOCK", 1)) + kernel.__signature__ = Signature(parameters=parameters) + kernel = triton.jit(kernel) + + torch_type = { + "int32": torch.int32, + "float32": torch.float32, + "float64": torch.float64 + } + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x = None + if dtype_str == "int32": + x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda") + else: + x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda") + if expr == 'libdevice.ffs': + y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda") + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + + # triton result + y = torch.zeros(shape, dtype=x.dtype, device="cuda") + kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path}) + # compare + assert_close(y, y_ref) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index f711fde24..35512fabe 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -36,6 +36,7 @@ def str_to_ty(name): "bf16": triton.language.bfloat16, "fp32": triton.language.float32, "fp64": triton.language.float64, + "i1": triton.language.int1, "i8": triton.language.int8, "i16": triton.language.int16, "i32": triton.language.int32, @@ -45,7 +46,6 @@ def str_to_ty(name): "u32": triton.language.uint32, "u64": triton.language.uint64, "B": triton.language.int1, - "i1": triton.language.int1, } return tys[name] @@ -888,6 +888,13 @@ def optimize_tritongpu_ir(mod, num_stages): return mod +def add_external_libs(mod, libs): + for name, path in libs.items(): + if len(name) == 0 or len(path) == 0: + return + _triton.add_external_libs(mod, list(libs.keys()), list(libs.values())) + + def make_llvm_ir(mod): return _triton.translate_triton_gpu_to_llvmir(mod) @@ -986,6 +993,8 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat module = optimize_tritongpu_ir(module, num_stages) if output == "ttgir": return module.str() + if extern_libs: + add_external_libs(module, extern_libs) # llvm-ir llvm_ir = make_llvm_ir(module) diff --git a/python/triton/language/libdevice.10.bc b/python/triton/language/libdevice.10.bc new file mode 100755 index 000000000..b2c75a502 Binary files /dev/null and b/python/triton/language/libdevice.10.bc differ diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b7fda1736..878e07a8f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -226,7 +226,6 @@ def fdiv(input: tl.tensor, raise ValueError("both operands of fdiv must have floating poscalar type") input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) ret = builder.create_fdiv(input.handle, other.handle) - ret.set_fdiv_ieee_rounding(ieee_rounding) return tl.tensor(ret, input.type) @@ -1074,7 +1073,8 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: x, y = binary_op_type_checking_impl(x, y, builder) - return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type) + from . import libdevice + return libdevice.mulhi(x, y, _builder=builder) def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: