diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d61848742..370575bbb 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -8,20 +8,19 @@ on: - triton-mlir jobs: - Runner-Preparation: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - name: Prepare runner matrix - id: set-matrix - run: | - if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-latest"]' - else - echo '::set-output name=matrix::["ubuntu-latest", "macos-latest"]' - fi + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"openai/triton" ]; then + echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]' + else + echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]' + fi Integration-Tests: needs: Runner-Preparation @@ -33,7 +32,6 @@ jobs: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}} steps: - - name: Checkout uses: actions/checkout@v2 @@ -42,33 +40,32 @@ jobs: rm -rf ~/.triton/cache/ - name: Check imports - if: ${{ matrix.runner != 'macos-latest' }} + if: startsWith(matrix.runner, 'ubuntu') run: | pip install isort isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 ) - name: Check python style - if: ${{ matrix.runner != 'macos-latest' }} + if: startsWith(matrix.runner, 'ubuntu') run: | pip install autopep8 autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 ) - name: Check cpp style - if: ${{ matrix.runner != 'macos-latest' }} + if: startsWith(matrix.runner, 'ubuntu') run: | pip install clang-format find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i || (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1) - name: Flake8 - if: ${{ matrix.runner != 'macos-latest' }} + if: startsWith(matrix.runner, 'ubuntu') run: | pip install flake8 flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 ) - name: Install Triton run: | - alias python='python3' cd python TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]' @@ -82,7 +79,7 @@ jobs: lit -v "$LIT_TEST_DIR" - name: Run python tests - if: ${{ matrix.runner[0] == 'self-hosted' }} + if: ${{matrix.runner[0] == 'self-hosted'}} run: | cd python/tests pytest diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 325ea3500..e089f82c4 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -30,8 +30,8 @@ class TT_Op traits = []> : // fptoui, fptosi, uitofp, sitofp, // extf, tructf, // extui, extsi, tructi -def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, - SameOperandsAndResultEncoding, +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast int64 to pointer"; @@ -43,7 +43,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)"; } -def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -57,7 +57,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, } // arith.bitcast doesn't support pointers -def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, +def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, // TODO: Add verifier } -def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, +def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -99,7 +99,7 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, // def TT_AddPtrOp : TT_Op<"addptr", - [NoSideEffect, + [NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultEncoding, TypesMatchWith<"result type matches ptr type", @@ -224,7 +224,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape, // // Shape Manipulation Ops // -def TT_SplatOp : TT_Op<"splat", [NoSideEffect, +def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "splat"; @@ -237,8 +237,8 @@ def TT_SplatOp : TT_Op<"splat", [NoSideEffect, let hasFolder = 1; } -def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, - DeclareOpInterfaceMethods, +def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, + DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "expand_dims"; @@ -249,7 +249,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } -def TT_ViewOp : TT_Op<"view", [NoSideEffect, +def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "view"; @@ -261,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect, } -def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, +def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "broadcast. No left-padding as of now."; @@ -274,7 +274,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, let hasFolder = 1; } -def TT_CatOp : TT_Op<"cat", [NoSideEffect, +def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "concatenate 2 tensors"; @@ -307,7 +307,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> { // // Dot Op // -def TT_DotOp : TT_Op<"dot", [NoSideEffect, +def TT_DotOp : TT_Op<"dot", [NoSideEffect, DeclareOpInterfaceMethods, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { @@ -357,12 +357,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe return $libpath/$libname:$symbol($args...) }]; - let arguments = (ins Variadic:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol); + let arguments = (ins Variadic:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol); - let results = (outs TT_Tensor:$result); + let results = (outs TT_Type:$result); let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)"; - } // @@ -385,4 +384,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> { let assemblyFormat = "attr-dict `:` type($result)"; } +// +// Make PrintfOp +// +def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$prefix, + Variadic>:$args)> { + let summary = "Device-side printf, as in CUDA for debugging"; + let description = [{ + `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict ($args^ `:` type($args))? + }]; +} + #endif // Triton_OPS diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 3c2dca2e6..5391574dd 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation( curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), newContiguity, newDivisibility, newConstancy); } + // TODO: All other binary ops + if (llvm::isa(op)) { + auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; + auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; + auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) { + return gcd(lhs.getConstancy(d), rhs.getConstancy(d)); + }; + curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), + newContiguity, newDivisibility, newConstancy); + } // Splat if (llvm::isa(op)) { Type _retTy = *op->result_type_begin(); @@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation( for (int d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); - constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); } curr = AxisInfo(contiguity, divisibility, constancy); } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 61740ae22..4aaddb7a2 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) +#define ui32_ty rewriter.getIntegerType(32, false) #define f16_ty rewriter.getF16Type() #define i8_ty rewriter.getIntegerType(8) #define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() #define vec_ty(type, num) VectorType::get(num, type) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) @@ -691,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = resType.cast(); - if (tensorTy.getEncoding().isa()) { + if (tensorTy.getEncoding().isa() || + tensorTy.getEncoding().isa()) { auto tensorTy = resType.cast(); auto srcType = typeConverter->convertType(elemType); auto llSrc = bitcast(constVal, srcType); @@ -1807,6 +1810,7 @@ public: } Value view = getStructFromElements(loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, view); + return success(); } @@ -4731,6 +4735,190 @@ struct FDivOpConversion } }; +struct PrintfOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + SmallVector operands; + for (auto operand : adaptor.getOperands()) { + auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + for (auto elem : sub_operands) { + operands.push_back(elem); + } + } + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << op.prefix(); + if (operands.size() > 0) { + os << getFormatSubstr(operands[0]); + } + + for (size_t i = 1; i < operands.size(); ++i) { + os << ", " << getFormatSubstr(operands[i]); + } + llPrintf(formatStr, operands, rewriter); + rewriter.eraseOp(op); + return success(); + } + // get format specific for each input value + // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64 + std::string getFormatSubstr(Value value) const { + Type type = value.getType(); + unsigned width = type.getIntOrFloatBitWidth(); + + if (type.isa()) { + return "%p"; + } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return "%f"; + } else if (type.isSignedInteger()) { + return "%i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + return "%u"; + } + assert(false && "not supported type"); + } + + // declare vprintf(i8*, i8*) as external function + LLVM::LLVMFuncOp + getVprintfDeclaration(ConversionPatternRewriter &rewriter) const { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("vprintf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), + ptr_ty(IntegerType::get(context, 8))}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); + } + + // extend integer to int32, extend float to float64 + // this comes from vprintf alignment requirements. + std::pair promoteValue(ConversionPatternRewriter &rewriter, + Value value) const { + auto *context = rewriter.getContext(); + auto type = value.getType(); + unsigned width = type.getIntOrFloatBitWidth(); + Value newOp = value; + Type newType = type; + + bool bUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && width < 32) { + if (bUnsigned) { + newType = ui32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } else { + newType = i32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + + return {newType, newOp}; + } + + void llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) const { + static const char formatStringPrefix[] = "printfFormat_"; + assert(!msg.empty() && "printf with empty string not support"); + Type int8Ptr = ptr_ty(i8_ty); + + auto *context = rewriter.getContext(); + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getVprintfDeclaration(rewriter); + + Value one = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1)); + Value zero = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0)); + + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> formatString(msg); + formatString.push_back('\n'); + formatString.push_back('\0'); + size_t formatStringSize = formatString.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize); + + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(context), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString)); + } + + Value globalPtr = + rewriter.create(UnknownLoc::get(context), global); + Value stringStart = + rewriter.create(UnknownLoc::get(context), int8Ptr, + globalPtr, mlir::ValueRange({zero, zero})); + + Value bufferPtr = + rewriter.create(UnknownLoc::get(context), int8Ptr); + + SmallVector newArgs; + if (args.size() >= 1) { + SmallVector argTypes; + for (auto arg : args) { + Type newType; + Value newArg; + std::tie(newType, newArg) = promoteValue(rewriter, arg); + argTypes.push_back(newType); + newArgs.push_back(newArg); + } + + Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes); + auto allocated = rewriter.create(UnknownLoc::get(context), + ptr_ty(structTy), one, + /*alignment=*/0); + + for (const auto &entry : llvm::enumerate(newArgs)) { + auto index = rewriter.create( + UnknownLoc::get(context), i32_ty, + rewriter.getI32IntegerAttr(entry.index())); + auto fieldPtr = rewriter.create( + UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]), + allocated, ArrayRef{zero, index}); + rewriter.create(UnknownLoc::get(context), entry.value(), + fieldPtr); + } + bufferPtr = rewriter.create(UnknownLoc::get(context), + int8Ptr, allocated); + } + + ValueRange operands{stringStart, bufferPtr}; + rewriter.create(UnknownLoc::get(context), funcOp, operands); + } +}; + void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, @@ -4817,6 +5005,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add>(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); + patterns.add(typeConverter, benefit); } class ConvertTritonGPUToLLVM diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 9519ff998..706d10fed 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern { } }; +struct TritonPrintfPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.prefixAttr(), + adaptor.getOperands()); + return success(); + } +}; + void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); @@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonReducePattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, - TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>( - typeConverter, context); + TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, + TritonPrintfPattern>(typeConverter, context); } // diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 36ea7030f..456ce1200 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -533,6 +533,35 @@ public: BlockedToMMA(mlir::MLIRContext *context) : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {} + static SmallVector + getWarpsPerTile(const ArrayRef &shape, int version, int numWarps) { + assert(version == 2); + // TODO: Handle one warp per row for fused matmuls + // TODO: unsigned -> int64_t to keep things uniform + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = {16, 8}; + bool changed = false; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + changed = false; + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; + } + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { @@ -541,13 +570,20 @@ public: auto oldRetType = dotOp.getResult().getType().cast(); if (oldRetType.getEncoding().isa()) return failure(); - // TODO: compute warpsPerCTA - auto newRetType = RankedTensorType::get( - oldRetType.getShape(), oldRetType.getElementType(), - triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2})); + // get MMA encoding for the given number of warps + auto retShape = oldRetType.getShape(); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + auto newRetType = + RankedTensorType::get(retShape, oldRetType.getElementType(), + triton::gpu::MmaEncodingAttr::get( + oldRetType.getContext(), 2, + getWarpsPerTile(retShape, 2, numWarps))); + // convert accumulator auto oldAcc = dotOp.getOperand(2); auto newAcc = rewriter.create( oldAcc.getLoc(), newRetType, oldAcc); + // convert output auto newDot = rewriter.create( dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB()); diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index be2e65b31..eaabb7c24 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -197,7 +197,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, ext_mod->setTargetTriple(llvmir->getTargetTriple()); ext_mod->setDataLayout(llvmir->getDataLayout()); - if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) { + if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod), + llvm::Linker::Flags::LinkOnlyNeeded)) { llvm::errs() << "Failed to link extern lib " << lib.first; return nullptr; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 56ee90b18..4cbc3f1e7 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, condition, trueValue, falseValue); + }) + .def("create_printf", + [](mlir::OpBuilder &self, const std::string &prefix, + const std::vector &values) -> void { + auto loc = self.getUnknownLoc(); + self.create( + loc, + mlir::StringAttr::get(self.getContext(), + llvm::StringRef(prefix)), + values); }); py::class_(m, "pass_manager") diff --git a/python/tests/printf_helper.py b/python/tests/printf_helper.py new file mode 100644 index 000000000..22e1350f1 --- /dev/null +++ b/python/tests/printf_helper.py @@ -0,0 +1,56 @@ +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +torch_type = { + "bool": torch.bool, + 'int8': torch.int8, + 'uint8': torch.uint8, + 'int16': torch.int16, + "int32": torch.int32, + 'int64': torch.long, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64 +} + + +def get_tensor(shape, data_type, b_positive=False): + x = None + if data_type.startswith('int'): + x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda') + else: + x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda') + + return x + +# @pytest.mark.parametrize('data_type', +# [("int8"), +# ('int16'), +# ('int32'), +# ("int64"), +# ('float16'), +# ("float32"), +# ("float64")]) + + +def printf(data_type): + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.printf("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x = get_tensor(shape, data_type) + y = torch.zeros(shape, dtype=x.dtype, device="cuda") + kernel[(1,)](x, y, BLOCK=shape[0]) + assert_close(y, x) + + +printf("float16") +printf("int8") diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 5fd25b118..7609c9419 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) - kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) @@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'): # # test math ops # # ---------------- -# TODO: Math module -# # @pytest.mark.parametrize("expr", [ -# # 'exp', 'log', 'cos', 'sin' -# # ]) - -# @pytest.mark.parametrize("expr", [ -# 'exp', 'log', 'cos', 'sin' -# ]) -# def test_math_op(expr, device='cuda'): -# _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) +@pytest.mark.parametrize("expr", [ + 'exp', 'log', 'cos', 'sin' +]) +def test_math_op(expr, device='cuda'): + _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) # # ---------------- @@ -1545,43 +1540,72 @@ def test_num_warps_pow2(): # # ------------- -# @pytest.mark.parametrize("dtype_str, expr, lib_path", -# [('int32', 'libdevice.ffs', ''), -# ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), -# ('float64', 'libdevice.norm4d', '')]) -# def test_libdevice(dtype_str, expr, lib_path): +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', ''), + ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float64', 'libdevice.norm4d', '')]) +def test_libdevice_tensor(dtype_str, expr, lib_path): -# @triton.jit -# def kernel(X, Y, BLOCK: tl.constexpr): -# x = tl.load(X + tl.arange(0, BLOCK)) -# y = GENERATE_TEST_HERE -# tl.store(Y + tl.arange(0, BLOCK), y) + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) -# shape = (128, ) -# rs = RandomState(17) -# # limit the range of integers so that the sum does not overflow -# x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) -# if expr == 'libdevice.ffs': -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) -# y_ref = np.zeros(shape, dtype=x.dtype) -# for i in range(shape[0]): -# y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() -# elif expr == 'libdevice.pow': -# # numpy does not allow negative factors in power, so we use abs() -# x = np.abs(x) -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) -# y_ref = np.power(x, x) -# elif expr == 'libdevice.norm4d': -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) -# y_ref = np.sqrt(4 * np.power(x, 2)) + if expr == 'libdevice.ffs': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) + y_ref = np.zeros(shape, dtype=x.dtype) + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + elif expr == 'libdevice.pow': + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref = np.power(x, x) + elif expr == 'libdevice.norm4d': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) + y_ref = np.sqrt(4 * np.power(x, 2)) -# x_tri = to_triton(x) -# # triton result -# y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') -# kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) -# # compare -# if expr == 'libdevice.ffs': -# np.testing.assert_equal(y_ref, to_numpy(y_tri)) -# else: -# np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + x_tri = to_triton(x) + # triton result + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + if expr == 'libdevice.ffs': + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + else: + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('float32', 'libdevice.pow', '')]) +def test_libdevice_scalar(dtype_str, expr, lib_path): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = X + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((1,), dtype_str=dtype_str, rs=rs) + y_ref = np.zeros(shape, dtype=x.dtype) + + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref[:] = np.power(x, x) + + # triton result + x_tri = to_triton(x)[0].item() + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/tests/test_printf.py b/python/tests/test_printf.py new file mode 100644 index 000000000..cab0573ca --- /dev/null +++ b/python/tests/test_printf.py @@ -0,0 +1,21 @@ +import os +import subprocess + +dir_path = os.path.dirname(os.path.realpath(__file__)) +printf_path = os.path.join(dir_path, "printf_helper.py") + + +def test_printf(): + proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False) + (outs, err) = proc.communicate() + outs = outs.split() + new_lines = set() + for line in outs: + try: + value = int(float(line)) + new_lines.add(value) + except Exception as e: + print(e) + for i in range(128): + assert i in new_lines + assert len(new_lines) == 128 diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8c2708074..94da20f9d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g): @triton.jit def zeros_like(input): return zeros(input.shape, input.dtype) + + +@builtin +def printf(prefix, *args, _builder=None): + import string + new_prefix = prefix + if isinstance(prefix, constexpr): + new_prefix = prefix.value + assert isinstance(new_prefix, str), f"{new_prefix} is not string" + b_ascii = True + for ch in new_prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{new_prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.printf(new_prefix, new_args, _builder) diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py index 2ef440633..3bb457fb8 100644 --- a/python/triton/language/extern.py +++ b/python/triton/language/extern.py @@ -56,28 +56,34 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: :return: the return value of the function ''' dispatch_args = args.copy() - if len(args) == 1: - dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) - ret_shape = dispatch_args[0].shape - elif len(args) == 2: - dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) - dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) - dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( - dispatch_args[0], dispatch_args[1], _builder) - ret_shape = dispatch_args[0].shape - else: - for i in range(len(dispatch_args)): - dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) - broadcast_arg = dispatch_args[0] - # Get the broadcast shape over all the arguments - for i in range(len(dispatch_args)): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder) - # Change the shape of each argument based on the broadcast shape - for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder) - ret_shape = broadcast_arg.shape + all_scalar = True + ret_shape = None + for dispatch_arg in dispatch_args: + if dispatch_arg.type.is_block(): + all_scalar = False + if not all_scalar: + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape func = getattr(_builder, "create_external_elementwise") return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 748470957..b56e40c35 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: def debug_barrier(builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_barrier(''), tl.void) + + +def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor: + new_args = [] + for arg in args: + new_args.append(arg.handle) + return tl.tensor(builder.create_printf(prefix, new_args), tl.void) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index f685a6059..82a264dc2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -157,15 +157,6 @@ import triton.language as tl @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), ], key=['M', 'N', 'K'], ) @@ -318,13 +309,13 @@ else: triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_vals=[ - 128 * i for i in range(2, 33) + 8192 ], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` - line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], + line_vals=['cublas', 'triton'], # label name for the lines - line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], + line_names=["cuBLAS", "Triton"], # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis @@ -336,18 +327,9 @@ def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) - if provider == 'cublas + relu': - torch_relu = torch.nn.ReLU(inplace=True) - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch_relu(torch.matmul(a, b)) - ) - if provider == 'triton + relu': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b, activation=leaky_relu) - ) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms)