[Triton-MLIR] Add PTX vprintf support (#825)

I don't know how to write a unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
ben-zhang-609 authored 2022-11-02 16:39:09 +08:00 (committed by GitHub)
parent 12d60cb4a3, commit 5feb6e24f9
9 changed files with 386 additions and 62 deletions


@@ -30,8 +30,8 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
// fptoui, fptosi, uitofp, sitofp,
// extf, truncf,
// extui, extsi, trunci
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
@@ -43,7 +43,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -57,7 +57,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
}
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -99,7 +99,7 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
@@ -224,7 +224,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "splat";
@@ -237,8 +237,8 @@ def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
@@ -249,7 +249,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
@@ -261,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect,
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
@@ -274,7 +274,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
@@ -307,7 +307,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
@@ -385,4 +385,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Printf Op
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments to print.
The format specifiers are generated automatically from the argument types.
}];
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
}];
}
#endif // Triton_OPS
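For context, the op defined above surfaces in the Python frontend as `tl.printf` (added later in this diff). A minimal end-to-end usage sketch, assuming a CUDA device and this branch's `triton` build; the kernel name and sizes are illustrative only:

import torch
import triton
import triton.language as tl

@triton.jit
def debug_kernel(X, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # prefix string first, then any number of scalar/tensor arguments;
    # format specifiers are derived from the argument element types
    tl.printf("x = ", x)

x = torch.arange(0, 16, dtype=torch.float32, device="cuda")
debug_kernel[(1,)](x, BLOCK=16)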


@@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -1807,6 +1809,7 @@ public:
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -4541,6 +4544,190 @@ struct FDivOpConversion
}
};
struct PrintfOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
for (auto elem : sub_operands) {
operands.push_back(elem);
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.prefix();
if (operands.size() > 0) {
os << getFormatSubstr(operands[0]);
}
for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}
// get the format specifier for one input value
// currently supports pointer, i8, i16, i32, i64, f16, bf16, f32, f64
std::string getFormatSubstr(Value value) const {
Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
} else if (type.isSignedInteger()) {
return "%i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
return "%u";
}
assert(false && "unsupported type");
return "";
}
// declare vprintf(i8*, i8*) as external function
LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("vprintf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
auto *context = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
ptr_ty(IntegerType::get(context, 8))};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
funcType);
}
// extend integers narrower than 32 bits to i32, and f16/bf16/f32 to f64;
// this follows vprintf's argument promotion and alignment requirements.
std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
Value value) const {
auto *context = rewriter.getContext();
auto type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
Value newOp = value;
Type newType = type;
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && width < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
value);
} else {
newType = i32_ty;
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
value);
}
return {newType, newOp};
}
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) const {
static const char formatStringPrefix[] = "printfFormat_";
assert(!msg.empty() && "printf with empty string not supported");
Type int8Ptr = ptr_ty(i8_ty);
auto *context = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
Value one = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> formatString(msg);
formatString.push_back('\n');
formatString.push_back('\0');
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(context), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
globalPtr, mlir::ValueRange({zero, zero}));
Value bufferPtr =
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
SmallVector<Value, 16> newArgs;
if (args.size() >= 1) {
SmallVector<Type> argTypes;
for (auto arg : args) {
Type newType;
Value newArg;
std::tie(newType, newArg) = promoteValue(rewriter, arg);
argTypes.push_back(newType);
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty,
rewriter.getI32IntegerAttr(entry.index()));
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
allocated, ArrayRef<Value>{zero, index});
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
fieldPtr);
}
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
int8Ptr, allocated);
}
ValueRange operands{stringStart, bufferPtr};
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
}
};
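To make the lowering above concrete: matchAndRewrite concatenates the prefix with one format specifier per scalar operand (comma separated), and llPrintf appends a newline plus a terminating NUL, stores the string as an internal global, packs the promoted arguments into a stack-allocated struct, and passes both pointers to vprintf. Below is a rough Python sketch of just the format-string construction, mirroring getFormatSubstr; the type-name strings and helper names are illustrative, not part of the C++ code:

def format_substr(elem_ty: str) -> str:
    # mirrors getFormatSubstr: pointers -> %p, floats -> %f,
    # signed ints -> %i, unsigned/signless ints -> %u
    if elem_ty == "ptr":
        return "%p"
    if elem_ty in ("bf16", "f16", "f32", "f64"):
        return "%f"
    if elem_ty.startswith("si"):
        return "%i"
    if elem_ty.startswith(("ui", "i")):
        return "%u"
    raise ValueError(f"unsupported type: {elem_ty}")

def build_format(prefix: str, elem_types) -> str:
    # prefix, then comma-separated specifiers, then the newline llPrintf appends
    return prefix + ", ".join(format_substr(t) for t in elem_types) + "\n"

# e.g. build_format("x = ", ["f32", "f32"]) == "x = %f, %f\n"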
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -4627,6 +4814,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
class ConvertTritonGPUToLLVM


@@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
}
};
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
adaptor.getOperands());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
typeConverter, context);
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintfPattern>(typeConverter, context);
}
//


@@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::SelectOp>(loc, condition, trueValue,
falseValue);
})
.def("create_printf",
[](mlir::OpBuilder &self, const std::string &prefix,
const std::vector<mlir::Value> &values) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::PrintfOp>(
loc,
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
});
py::class_<mlir::PassManager>(m, "pass_manager")


@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
torch_type = {
"bool": torch.bool,
'int8': torch.int8,
'uint8': torch.uint8,
'int16': torch.int16,
"int32": torch.int32,
'int64': torch.long,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
"float32": torch.float32,
"float64": torch.float64
}
def get_tensor(shape, data_type, b_positive=False):
# build a ramp tensor of the requested dtype on the GPU
return torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
# @pytest.mark.parametrize('data_type',
# [("int8"),
# ('int16'),
# ('int32'),
# ("int64"),
# ('float16'),
# ("float32"),
# ("float64")])
def printf(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.printf("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
printf("float16")
printf("int8")


@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# # test math ops
# # ----------------
# TODO: Math module
# # @pytest.mark.parametrize("expr", [
# # 'exp', 'log', 'cos', 'sin'
# # ])
# @pytest.mark.parametrize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
# def test_math_op(expr, device='cuda'):
# _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# # ----------------
@@ -1545,43 +1540,43 @@ def test_num_warps_pow2():
# # -------------
# @pytest.mark.parametrize("dtype_str, expr, lib_path",
# [('int32', 'libdevice.ffs', ''),
# ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
# ('float64', 'libdevice.norm4d', '')])
# def test_libdevice(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path):
# @triton.jit
# def kernel(X, Y, BLOCK: tl.constexpr):
# x = tl.load(X + tl.arange(0, BLOCK))
# y = GENERATE_TEST_HERE
# tl.store(Y + tl.arange(0, BLOCK), y)
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
# shape = (128, )
# rs = RandomState(17)
# # limit the range of integers so that the sum does not overflow
# x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
# if expr == 'libdevice.ffs':
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
# y_ref = np.zeros(shape, dtype=x.dtype)
# for i in range(shape[0]):
# y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
# elif expr == 'libdevice.pow':
# # numpy does not allow negative factors in power, so we use abs()
# x = np.abs(x)
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
# y_ref = np.power(x, x)
# elif expr == 'libdevice.norm4d':
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
# y_ref = np.sqrt(4 * np.power(x, 2))
if expr == 'libdevice.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
y_ref = np.zeros(shape, dtype=x.dtype)
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
elif expr == 'libdevice.pow':
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref = np.power(x, x)
elif expr == 'libdevice.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
y_ref = np.sqrt(4 * np.power(x, 2))
# x_tri = to_triton(x)
# # triton result
# y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
# kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# # compare
# if expr == 'libdevice.ffs':
# np.testing.assert_equal(y_ref, to_numpy(y_tri))
# else:
# np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
x_tri = to_triton(x)
# triton result
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
if expr == 'libdevice.ffs':
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)


@@ -0,0 +1,21 @@
import os
import subprocess
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128


@@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
@builtin
def printf(prefix, *args, _builder=None):
import string
new_prefix = prefix
if isinstance(prefix, constexpr):
new_prefix = prefix.value
assert isinstance(new_prefix, str), f"{new_prefix} is not string"
b_ascii = all(ch in string.printable for ch in new_prefix)
assert b_ascii, f"{new_prefix} is not an ascii string"
new_args = []
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
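One usage note on the builtin above (a sketch, not part of this diff): inside a @triton.jit function the prefix typically arrives as a tl.constexpr, which is why the builtin unwraps constexpr before the str check, and scalar arguments are wrapped into tensors via _to_tensor. For example, mixing a scalar and a tensor argument:

@triton.jit
def kernel(X, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    x = tl.load(X + tl.arange(0, BLOCK))
    # both pid and x become operands of a single tt.printf op
    tl.printf("pid, x = ", pid, x)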


@@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(''), tl.void)
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
new_args = []
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)