[Triton-MLIR] Add PTX vprintf support (#825)
Not sure yet how to write a unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
@@ -30,8 +30,8 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
// fptoui, fptosi, uitofp, sitofp,
// extf, truncf,
// extui, extsi, trunci
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
                                         SameOperandsAndResultEncoding,
                                         NoSideEffect,
                                         /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
    let summary = "Cast int64 to pointer";
@@ -43,7 +43,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
    let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}

def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
                                         SameOperandsAndResultEncoding,
                                         NoSideEffect,
                                         /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -57,7 +57,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
}

// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
                                     SameOperandsAndResultEncoding,
                                     NoSideEffect,
                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
    // TODO: Add verifier
}

def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
                                   SameOperandsAndResultEncoding,
                                   NoSideEffect,
                                   /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -99,7 +99,7 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//

def TT_AddPtrOp : TT_Op<"addptr",
                        [NoSideEffect,
                         SameOperandsAndResultShape,
                         SameOperandsAndResultEncoding,
                         TypesMatchWith<"result type matches ptr type",
@@ -224,7 +224,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
                                 SameOperandsAndResultElementType]> {
    let summary = "splat";

@@ -237,8 +237,8 @@ def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
    let hasFolder = 1;
}

def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
                                            DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                            SameOperandsAndResultElementType]> {
    let summary = "expand_dims";

@@ -249,7 +249,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}

def TT_ViewOp : TT_Op<"view", [NoSideEffect,
                               SameOperandsAndResultElementType]> {
    let summary = "view";

@@ -261,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect,

}

def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
                                         SameOperandsAndResultElementType]> {
    let summary = "broadcast. No left-padding as of now.";

@@ -274,7 +274,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
    let hasFolder = 1;
}

def TT_CatOp : TT_Op<"cat", [NoSideEffect,
                             SameOperandsAndResultElementType]> {
    let summary = "concatenate 2 tensors";

@@ -307,7 +307,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
                             DeclareOpInterfaceMethods<InferTypeOpInterface>,
                             TypesMatchWith<"result's type matches accumulator's type",
                                            "d", "c", "$_self">]> {
@@ -385,4 +385,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
    let assemblyFormat = "attr-dict `:` type($result)";
}

//
// Printf Op
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
                  Arguments<(ins StrAttr:$prefix,
                                 Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
    let summary = "Device-side printf, as in CUDA for debugging";
    let description = [{
        `tt.printf` takes a literal string prefix and an arbitrary number of scalar or
        tensor arguments that should be printed. The format specifiers are generated
        automatically from the argument types.
    }];
    let assemblyFormat = [{
        $prefix attr-dict ($args^ `:` type($args))?
    }];
}
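
For reference, a minimal sketch of how this op is reached from the Python frontend added later in this commit (the kernel, tensor, and block size below are illustrative, not part of the diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def print_kernel(X, BLOCK: tl.constexpr):
        # tl.printf prepends the literal prefix and picks one format
        # specifier per element automatically (%f, %i, %u, or %p)
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("x: ", x)

    x = torch.arange(0, 16, dtype=torch.int32, device="cuda")
    print_kernel[(1,)](x, BLOCK=16)
    torch.cuda.synchronize()  # make sure the device-side output has been flushed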

#endif // Triton_OPS

@@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -1807,6 +1809,7 @@ public:
    }
    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, view);

    return success();
  }

@@ -4541,6 +4544,190 @@ struct FDivOpConversion
  }
};

struct PrintfOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    SmallVector<Value, 16> operands;
    for (auto operand : adaptor.getOperands()) {
      auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
      for (auto elem : sub_operands) {
        operands.push_back(elem);
      }
    }
    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    os << op.prefix();
    if (operands.size() > 0) {
      os << getFormatSubstr(operands[0]);
    }

    for (size_t i = 1; i < operands.size(); ++i) {
      os << ", " << getFormatSubstr(operands[i]);
    }
    llPrintf(formatStr, operands, rewriter);
    rewriter.eraseOp(op);
    return success();
  }

  // get the format specifier for each input value
  // currently supports pointer, i8, i16, i32, i64, f16, bf16, f32, f64
  std::string getFormatSubstr(Value value) const {
    Type type = value.getType();

    if (type.isa<LLVM::LLVMPointerType>()) {
      return "%p";
    } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
      return "%f";
    } else if (type.isSignedInteger()) {
      return "%i";
    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
      return "%u";
    }
    assert(false && "unsupported type");
    return "";
  }

  // declare vprintf(i8*, i8*) as an external function
  LLVM::LLVMFuncOp
  getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
    auto moduleOp =
        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
    StringRef funcName("vprintf");
    Operation *funcOp = moduleOp.lookupSymbol(funcName);
    if (funcOp)
      return cast<LLVM::LLVMFuncOp>(*funcOp);

    auto *context = rewriter.getContext();

    SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
                               ptr_ty(IntegerType::get(context, 8))};
    auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);

    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());

    return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
                                             funcType);
  }

  // extend integers to int32 and floats to float64;
  // this comes from vprintf alignment requirements.
  std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
                                      Value value) const {
    auto *context = rewriter.getContext();
    auto type = value.getType();
    unsigned width = type.getIntOrFloatBitWidth();
    Value newOp = value;
    Type newType = type;

    bool bUnsigned = type.isUnsignedInteger();
    if (type.isIntOrIndex() && width < 32) {
      if (bUnsigned) {
        newType = ui32_ty;
        newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
                                              value);
      } else {
        newType = i32_ty;
        newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
                                              value);
      }
    } else if (type.isBF16() || type.isF16() || type.isF32()) {
      newType = f64_ty;
      newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
                                             value);
    }

    return {newType, newOp};
  }

  void llPrintf(StringRef msg, ValueRange args,
                ConversionPatternRewriter &rewriter) const {
    static const char formatStringPrefix[] = "printfFormat_";
    assert(!msg.empty() && "printf with empty string not supported");
    Type int8Ptr = ptr_ty(i8_ty);

    auto *context = rewriter.getContext();
    auto moduleOp =
        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
    auto funcOp = getVprintfDeclaration(rewriter);

    Value one = rewriter.create<LLVM::ConstantOp>(
        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
    Value zero = rewriter.create<LLVM::ConstantOp>(
        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));

    unsigned stringNumber = 0;
    SmallString<16> stringConstName;
    do {
      stringConstName.clear();
      (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
    } while (moduleOp.lookupSymbol(stringConstName));

    llvm::SmallString<64> formatString(msg);
    formatString.push_back('\n');
    formatString.push_back('\0');
    size_t formatStringSize = formatString.size_in_bytes();
    auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);

    LLVM::GlobalOp global;
    {
      ConversionPatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());
      global = rewriter.create<LLVM::GlobalOp>(
          UnknownLoc::get(context), globalType,
          /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
          rewriter.getStringAttr(formatString));
    }

    Value globalPtr =
        rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
    Value stringStart =
        rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
                                     globalPtr, mlir::ValueRange({zero, zero}));

    Value bufferPtr =
        rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);

    SmallVector<Value, 16> newArgs;
    if (args.size() >= 1) {
      SmallVector<Type> argTypes;
      for (auto arg : args) {
        Type newType;
        Value newArg;
        std::tie(newType, newArg) = promoteValue(rewriter, arg);
        argTypes.push_back(newType);
        newArgs.push_back(newArg);
      }

      Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
      auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
                                                       ptr_ty(structTy), one,
                                                       /*alignment=*/0);

      for (const auto &entry : llvm::enumerate(newArgs)) {
        auto index = rewriter.create<LLVM::ConstantOp>(
            UnknownLoc::get(context), i32_ty,
            rewriter.getI32IntegerAttr(entry.index()));
        auto fieldPtr = rewriter.create<LLVM::GEPOp>(
            UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
            allocated, ArrayRef<Value>{zero, index});
        rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
                                       fieldPtr);
      }
      bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
                                                   int8Ptr, allocated);
    }

    ValueRange operands{stringStart, bufferPtr};
    rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
  }
};
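
To summarize the lowering above: the op's prefix plus one automatically chosen specifier per scalar element become the format string, narrow arguments are widened, each promoted value is stored into a stack-allocated struct, and the format string plus the struct's address are passed to CUDA's vprintf. Below is a rough Python model of the specifier and promotion rules (an illustration only, using made-up "kind"/"width" tags rather than MLIR types):

    def format_substr(kind: str) -> str:
        # mirrors getFormatSubstr above
        if kind == "pointer":
            return "%p"
        if kind == "float":  # bf16, f16, f32, f64
            return "%f"
        if kind == "signed_int":
            return "%i"
        if kind in ("unsigned_int", "signless_int"):
            # integers coming out of the standard lowerings are signless,
            # so in practice they are printed with %u
            return "%u"
        raise ValueError(f"unsupported kind: {kind}")

    def promoted_width(kind: str, width: int) -> int:
        # mirrors promoteValue: integers narrower than 32 bits are extended
        # to 32 bits, bf16/f16/f32 are extended to f64 (vprintf alignment)
        if kind.endswith("int") and width < 32:
            return 32
        if kind == "float" and width < 64:
            return 64
        return width

    # e.g. a tt.printf with prefix "x: " and one f16 element per thread calls
    # vprintf("x: %f\n", <buffer whose single field holds the value as f64>)
    print(format_substr("float"), promoted_width("float", 16))  # -> %f 64
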
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
                                  AxisInfoAnalysis &axisInfoAnalysis,
@@ -4627,6 +4814,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
                                                           benefit);
  patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
  patterns.add<PrintfOpConversion>(typeConverter, benefit);
}

class ConvertTritonGPUToLLVM

@@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  }
};

struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
  using OpConversionPattern<PrintfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                  adaptor.getOperands());
    return success();
  }
};

void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
@@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
      TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
      TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
      typeConverter, context);
      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
      TritonPrintfPattern>(typeConverter, context);
}

//

@@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::SelectOp>(loc, condition, trueValue,
                                                falseValue);
           })
      .def("create_printf",
           [](mlir::OpBuilder &self, const std::string &prefix,
              const std::vector<mlir::Value> &values) -> void {
             auto loc = self.getUnknownLoc();
             self.create<mlir::triton::PrintfOp>(
                 loc,
                 mlir::StringAttr::get(self.getContext(),
                                       llvm::StringRef(prefix)),
                 values);
           });

  py::class_<mlir::PassManager>(m, "pass_manager")

python/tests/printf_helper.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    "int32": torch.int32,
    'int64': torch.long,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}


def get_tensor(shape, data_type, b_positive=False):
    x = None
    if data_type.startswith('int'):
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')

    return x

# @pytest.mark.parametrize('data_type',
#                          [("int8"),
#                           ('int16'),
#                           ('int32'),
#                           ("int64"),
#                           ('float16'),
#                           ("float32"),
#                           ("float64")])


def printf(data_type):
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = get_tensor(shape, data_type)
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


printf("float16")
printf("int8")

@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
    # triton result
    x_tri = to_triton(x, device=device, dst_type=dtype_x)
    z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
    # compare
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# # test math ops
# # ----------------

# TODO: Math module
# # @pytest.mark.parametrize("expr", [
# #     'exp', 'log', 'cos', 'sin'
# # ])

# @pytest.mark.parametrize("expr", [
#     'exp', 'log', 'cos', 'sin'
# ])
# def test_math_op(expr, device='cuda'):
#     _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
@pytest.mark.parametrize("expr", [
    'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)


# # ----------------
@@ -1545,43 +1540,43 @@ def test_num_warps_pow2():
# # -------------

# @pytest.mark.parametrize("dtype_str, expr, lib_path",
#                          [('int32', 'libdevice.ffs', ''),
#                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
#                           ('float64', 'libdevice.norm4d', '')])
# def test_libdevice(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('int32', 'libdevice.ffs', ''),
                          ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
                          ('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path):

# @triton.jit
# def kernel(X, Y, BLOCK: tl.constexpr):
#     x = tl.load(X + tl.arange(0, BLOCK))
#     y = GENERATE_TEST_HERE
#     tl.store(Y + tl.arange(0, BLOCK), y)
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = GENERATE_TEST_HERE
        tl.store(Y + tl.arange(0, BLOCK), y)

# shape = (128, )
# rs = RandomState(17)
# # limit the range of integers so that the sum does not overflow
# x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
    shape = (128, )
    rs = RandomState(17)
    # limit the range of integers so that the sum does not overflow
    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)

# if expr == 'libdevice.ffs':
#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
#     y_ref = np.zeros(shape, dtype=x.dtype)
#     for i in range(shape[0]):
#         y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
# elif expr == 'libdevice.pow':
#     # numpy does not allow negative factors in power, so we use abs()
#     x = np.abs(x)
#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
#     y_ref = np.power(x, x)
# elif expr == 'libdevice.norm4d':
#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
#     y_ref = np.sqrt(4 * np.power(x, 2))
    if expr == 'libdevice.ffs':
        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
        y_ref = np.zeros(shape, dtype=x.dtype)
        for i in range(shape[0]):
            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
    elif expr == 'libdevice.pow':
        # numpy does not allow negative factors in power, so we use abs()
        x = np.abs(x)
        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
        y_ref = np.power(x, x)
    elif expr == 'libdevice.norm4d':
        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
        y_ref = np.sqrt(4 * np.power(x, 2))

# x_tri = to_triton(x)
# # triton result
# y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
# kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# # compare
# if expr == 'libdevice.ffs':
#     np.testing.assert_equal(y_ref, to_numpy(y_tri))
# else:
#     np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
    x_tri = to_triton(x)
    # triton result
    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
    # compare
    if expr == 'libdevice.ffs':
        np.testing.assert_equal(y_ref, to_numpy(y_tri))
    else:
        np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

python/tests/test_printf.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import os
import subprocess

dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")


def test_printf():
    # run the printf kernels in a separate process so their device-side
    # output can be captured from stdout and checked here
    proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
    (outs, err) = proc.communicate()
    outs = outs.split()
    new_lines = set()
    for line in outs:
        try:
            value = int(float(line))
            new_lines.add(value)
        except Exception as e:
            print(e)
    for i in range(128):
        assert i in new_lines
    assert len(new_lines) == 128

@@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
@triton.jit
def zeros_like(input):
    return zeros(input.shape, input.dtype)


@builtin
def printf(prefix, *args, _builder=None):
    import string
    new_prefix = prefix
    if isinstance(prefix, constexpr):
        new_prefix = prefix.value
    assert isinstance(new_prefix, str), f"{new_prefix} is not a string"
    b_ascii = True
    for ch in new_prefix:
        if ch not in string.printable:
            b_ascii = False
            break
    assert b_ascii, f"{new_prefix} is not an ascii string"
    new_args = []
    for arg in args:
        new_args.append(_to_tensor(arg, _builder))
    return semantic.printf(new_prefix, new_args, _builder)

@@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:

def debug_barrier(builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_barrier(''), tl.void)


def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
    new_args = []
    for arg in args:
        new_args.append(arg.handle)
    return tl.tensor(builder.create_printf(prefix, new_args), tl.void)