diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 325ea3500..364f942e0 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -30,8 +30,8 @@ class TT_Op traits = []> : // fptoui, fptosi, uitofp, sitofp, // extf, tructf, // extui, extsi, tructi -def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, - SameOperandsAndResultEncoding, +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast int64 to pointer"; @@ -43,7 +43,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)"; } -def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -57,7 +57,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, } // arith.bitcast doesn't support pointers -def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, +def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, // TODO: Add verifier } -def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, +def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, NoSideEffect, /*DeclareOpInterfaceMethods*/]> { @@ -99,7 +99,7 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, // def TT_AddPtrOp : TT_Op<"addptr", - [NoSideEffect, + [NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultEncoding, TypesMatchWith<"result type matches ptr type", @@ -224,7 +224,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape, // // Shape Manipulation Ops // -def TT_SplatOp : TT_Op<"splat", [NoSideEffect, +def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "splat"; @@ -237,8 +237,8 @@ def TT_SplatOp : TT_Op<"splat", [NoSideEffect, let hasFolder = 1; } -def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, - DeclareOpInterfaceMethods, +def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, + DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "expand_dims"; @@ -249,7 +249,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } -def TT_ViewOp : TT_Op<"view", [NoSideEffect, +def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "view"; @@ -261,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect, } -def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, +def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "broadcast. No left-padding as of now."; @@ -274,7 +274,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, let hasFolder = 1; } -def TT_CatOp : TT_Op<"cat", [NoSideEffect, +def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "concatenate 2 tensors"; @@ -307,7 +307,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> { // // Dot Op // -def TT_DotOp : TT_Op<"dot", [NoSideEffect, +def TT_DotOp : TT_Op<"dot", [NoSideEffect, DeclareOpInterfaceMethods, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { @@ -385,4 +385,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> { let assemblyFormat = "attr-dict `:` type($result)"; } +// +// Make PrintfOp +// +def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$prefix, + Variadic>:$args)> { + let summary = "Device-side printf, as in CUDA for debugging"; + let description = [{ + `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict ($args^ `:` type($args))? + }]; +} + #endif // Triton_OPS diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index a8da162a3..44caad4e5 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) +#define ui32_ty rewriter.getIntegerType(32, false) #define f16_ty rewriter.getF16Type() #define i8_ty rewriter.getIntegerType(8) #define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() #define vec_ty(type, num) VectorType::get(num, type) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) @@ -1807,6 +1809,7 @@ public: } Value view = getStructFromElements(loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, view); + return success(); } @@ -4541,6 +4544,190 @@ struct FDivOpConversion } }; +struct PrintfOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + SmallVector operands; + for (auto operand : adaptor.getOperands()) { + auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + for (auto elem : sub_operands) { + operands.push_back(elem); + } + } + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << op.prefix(); + if (operands.size() > 0) { + os << getFormatSubstr(operands[0]); + } + + for (size_t i = 1; i < operands.size(); ++i) { + os << ", " << getFormatSubstr(operands[i]); + } + llPrintf(formatStr, operands, rewriter); + rewriter.eraseOp(op); + return success(); + } + // get format specific for each input value + // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64 + std::string getFormatSubstr(Value value) const { + Type type = value.getType(); + unsigned width = type.getIntOrFloatBitWidth(); + + if (type.isa()) { + return "%p"; + } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return "%f"; + } else if (type.isSignedInteger()) { + return "%i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + return "%u"; + } + assert(false && "not supported type"); + } + + // declare vprintf(i8*, i8*) as external function + LLVM::LLVMFuncOp + getVprintfDeclaration(ConversionPatternRewriter &rewriter) const { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("vprintf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), + ptr_ty(IntegerType::get(context, 8))}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); + } + + // extend integer to int32, extend float to float64 + // this comes from vprintf alignment requirements. + std::pair promoteValue(ConversionPatternRewriter &rewriter, + Value value) const { + auto *context = rewriter.getContext(); + auto type = value.getType(); + unsigned width = type.getIntOrFloatBitWidth(); + Value newOp = value; + Type newType = type; + + bool bUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && width < 32) { + if (bUnsigned) { + newType = ui32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } else { + newType = i32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + + return {newType, newOp}; + } + + void llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) const { + static const char formatStringPrefix[] = "printfFormat_"; + assert(!msg.empty() && "printf with empty string not support"); + Type int8Ptr = ptr_ty(i8_ty); + + auto *context = rewriter.getContext(); + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getVprintfDeclaration(rewriter); + + Value one = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1)); + Value zero = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0)); + + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> formatString(msg); + formatString.push_back('\n'); + formatString.push_back('\0'); + size_t formatStringSize = formatString.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize); + + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(context), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString)); + } + + Value globalPtr = + rewriter.create(UnknownLoc::get(context), global); + Value stringStart = + rewriter.create(UnknownLoc::get(context), int8Ptr, + globalPtr, mlir::ValueRange({zero, zero})); + + Value bufferPtr = + rewriter.create(UnknownLoc::get(context), int8Ptr); + + SmallVector newArgs; + if (args.size() >= 1) { + SmallVector argTypes; + for (auto arg : args) { + Type newType; + Value newArg; + std::tie(newType, newArg) = promoteValue(rewriter, arg); + argTypes.push_back(newType); + newArgs.push_back(newArg); + } + + Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes); + auto allocated = rewriter.create(UnknownLoc::get(context), + ptr_ty(structTy), one, + /*alignment=*/0); + + for (const auto &entry : llvm::enumerate(newArgs)) { + auto index = rewriter.create( + UnknownLoc::get(context), i32_ty, + rewriter.getI32IntegerAttr(entry.index())); + auto fieldPtr = rewriter.create( + UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]), + allocated, ArrayRef{zero, index}); + rewriter.create(UnknownLoc::get(context), entry.value(), + fieldPtr); + } + bufferPtr = rewriter.create(UnknownLoc::get(context), + int8Ptr, allocated); + } + + ValueRange operands{stringStart, bufferPtr}; + rewriter.create(UnknownLoc::get(context), funcOp, operands); + } +}; + void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, @@ -4627,6 +4814,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add>(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); + patterns.add(typeConverter, benefit); } class ConvertTritonGPUToLLVM diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 9519ff998..706d10fed 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern { } }; +struct TritonPrintfPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.prefixAttr(), + adaptor.getOperands()); + return success(); + } +}; + void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); @@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonReducePattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, - TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>( - typeConverter, context); + TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, + TritonPrintfPattern>(typeConverter, context); } // diff --git a/python/src/triton.cc b/python/src/triton.cc index 56ee90b18..4cbc3f1e7 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, condition, trueValue, falseValue); + }) + .def("create_printf", + [](mlir::OpBuilder &self, const std::string &prefix, + const std::vector &values) -> void { + auto loc = self.getUnknownLoc(); + self.create( + loc, + mlir::StringAttr::get(self.getContext(), + llvm::StringRef(prefix)), + values); }); py::class_(m, "pass_manager") diff --git a/python/tests/printf_helper.py b/python/tests/printf_helper.py new file mode 100644 index 000000000..22e1350f1 --- /dev/null +++ b/python/tests/printf_helper.py @@ -0,0 +1,56 @@ +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +torch_type = { + "bool": torch.bool, + 'int8': torch.int8, + 'uint8': torch.uint8, + 'int16': torch.int16, + "int32": torch.int32, + 'int64': torch.long, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64 +} + + +def get_tensor(shape, data_type, b_positive=False): + x = None + if data_type.startswith('int'): + x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda') + else: + x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda') + + return x + +# @pytest.mark.parametrize('data_type', +# [("int8"), +# ('int16'), +# ('int32'), +# ("int64"), +# ('float16'), +# ("float32"), +# ("float64")]) + + +def printf(data_type): + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.printf("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + shape = (128, ) + # limit the range of integers so that the sum does not overflow + x = get_tensor(shape, data_type) + y = torch.zeros(shape, dtype=x.dtype, device="cuda") + kernel[(1,)](x, y, BLOCK=shape[0]) + assert_close(y, x) + + +printf("float16") +printf("int8") diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 5fd25b118..b8cb40bd5 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) - kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) @@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'): # # test math ops # # ---------------- -# TODO: Math module -# # @pytest.mark.parametrize("expr", [ -# # 'exp', 'log', 'cos', 'sin' -# # ]) - -# @pytest.mark.parametrize("expr", [ -# 'exp', 'log', 'cos', 'sin' -# ]) -# def test_math_op(expr, device='cuda'): -# _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) +@pytest.mark.parametrize("expr", [ + 'exp', 'log', 'cos', 'sin' +]) +def test_math_op(expr, device='cuda'): + _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) # # ---------------- @@ -1545,43 +1540,43 @@ def test_num_warps_pow2(): # # ------------- -# @pytest.mark.parametrize("dtype_str, expr, lib_path", -# [('int32', 'libdevice.ffs', ''), -# ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), -# ('float64', 'libdevice.norm4d', '')]) -# def test_libdevice(dtype_str, expr, lib_path): +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', ''), + ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float64', 'libdevice.norm4d', '')]) +def test_libdevice(dtype_str, expr, lib_path): -# @triton.jit -# def kernel(X, Y, BLOCK: tl.constexpr): -# x = tl.load(X + tl.arange(0, BLOCK)) -# y = GENERATE_TEST_HERE -# tl.store(Y + tl.arange(0, BLOCK), y) + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) -# shape = (128, ) -# rs = RandomState(17) -# # limit the range of integers so that the sum does not overflow -# x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) -# if expr == 'libdevice.ffs': -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) -# y_ref = np.zeros(shape, dtype=x.dtype) -# for i in range(shape[0]): -# y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() -# elif expr == 'libdevice.pow': -# # numpy does not allow negative factors in power, so we use abs() -# x = np.abs(x) -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) -# y_ref = np.power(x, x) -# elif expr == 'libdevice.norm4d': -# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) -# y_ref = np.sqrt(4 * np.power(x, 2)) + if expr == 'libdevice.ffs': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) + y_ref = np.zeros(shape, dtype=x.dtype) + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + elif expr == 'libdevice.pow': + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref = np.power(x, x) + elif expr == 'libdevice.norm4d': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) + y_ref = np.sqrt(4 * np.power(x, 2)) -# x_tri = to_triton(x) -# # triton result -# y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') -# kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) -# # compare -# if expr == 'libdevice.ffs': -# np.testing.assert_equal(y_ref, to_numpy(y_tri)) -# else: -# np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + x_tri = to_triton(x) + # triton result + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + if expr == 'libdevice.ffs': + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + else: + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/tests/test_printf.py b/python/tests/test_printf.py new file mode 100644 index 000000000..cab0573ca --- /dev/null +++ b/python/tests/test_printf.py @@ -0,0 +1,21 @@ +import os +import subprocess + +dir_path = os.path.dirname(os.path.realpath(__file__)) +printf_path = os.path.join(dir_path, "printf_helper.py") + + +def test_printf(): + proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False) + (outs, err) = proc.communicate() + outs = outs.split() + new_lines = set() + for line in outs: + try: + value = int(float(line)) + new_lines.add(value) + except Exception as e: + print(e) + for i in range(128): + assert i in new_lines + assert len(new_lines) == 128 diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8c2708074..94da20f9d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g): @triton.jit def zeros_like(input): return zeros(input.shape, input.dtype) + + +@builtin +def printf(prefix, *args, _builder=None): + import string + new_prefix = prefix + if isinstance(prefix, constexpr): + new_prefix = prefix.value + assert isinstance(new_prefix, str), f"{new_prefix} is not string" + b_ascii = True + for ch in new_prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{new_prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.printf(new_prefix, new_args, _builder) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 748470957..b56e40c35 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: def debug_barrier(builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_barrier(''), tl.void) + + +def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor: + new_args = [] + for arg in args: + new_args.append(arg.handle) + return tl.tensor(builder.create_printf(prefix, new_args), tl.void)