Merge remote-tracking branch 'origin/triton-mlir' into port-fma

This commit is contained in:
Superjomn
2022-11-04 17:43:54 +08:00
15 changed files with 519 additions and 133 deletions

View File

@@ -8,7 +8,6 @@ on:
- triton-mlir
jobs:
Runner-Preparation:
runs-on: ubuntu-latest
outputs:
@@ -18,9 +17,9 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-latest"]'
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-latest"]'
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
Integration-Tests:
@@ -33,7 +32,6 @@ jobs:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -42,33 +40,32 @@ jobs:
rm -rf ~/.triton/cache/
- name: Check imports
if: ${{ matrix.runner != 'macos-latest' }}
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: ${{ matrix.runner != 'macos-latest' }}
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: ${{ matrix.runner != 'macos-latest' }}
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: ${{ matrix.runner != 'macos-latest' }}
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton
run: |
alias python='python3'
cd python
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
@@ -82,7 +79,7 @@ jobs:
lit -v "$LIT_TEST_DIR"
- name: Run python tests
if: ${{ matrix.runner[0] == 'self-hosted' }}
if: ${{matrix.runner[0] == 'self-hosted'}}
run: |
cd python/tests
pytest

View File

@@ -357,12 +357,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Tensor>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
//
@@ -385,4 +384,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Make PrintfOp
//
// `tt.printf`: debugging op that prints a string prefix followed by its
// operands. It declares no results and is tagged with a MemWrite effect
// (rather than NoSideEffect like the neighbouring ops).
//   $prefix : string attribute placed at the start of the output.
//   $args   : zero or more operands of any TT_Type.
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
format are generated automatically from the arguments.
}];
// Textual form: the prefix, the attribute dictionary, then (only when there
// are operands) the operand list followed by `:` and the operand types.
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
}];
}
#endif // Triton_OPS

View File

@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Splat
if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin();
@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
constancy.push_back(opShape[d] == 1 ? retShape[d]
: opInfo.getConstancy(d));
}
curr = AxisInfo(contiguity, divisibility, constancy);
}

View File

@@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -691,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
@@ -1807,6 +1810,7 @@ public:
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -4731,6 +4735,190 @@ struct FDivOpConversion
}
};
/// Lowers `triton::PrintfOp` to a call to the external `vprintf(i8*, i8*)`.
///
/// Every operand is unpacked into its per-thread elements, a format string is
/// synthesised as `<prefix><spec>, <spec>, ...`, the (promoted) elements are
/// stored into a stack-allocated struct, and `vprintf` is called with a pointer
/// to the format string and a pointer to that struct.
struct PrintfOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;

  /// Flattens all operands, builds the format string and emits the call.
  /// The original op is erased (it has no results to replace).
  LogicalResult
  matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    SmallVector<Value, 16> operands;
    for (auto operand : adaptor.getOperands()) {
      auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
      for (auto elem : sub_operands) {
        operands.push_back(elem);
      }
    }
    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    os << op.prefix();
    if (operands.size() > 0) {
      os << getFormatSubstr(operands[0]);
    }

    for (size_t i = 1; i < operands.size(); ++i) {
      os << ", " << getFormatSubstr(operands[i]);
    }
    llPrintf(formatStr, operands, rewriter);
    rewriter.eraseOp(op);
    return success();
  }

  /// Returns the printf conversion specifier for `value`.
  /// Currently supports pointer, i8, i16, i32, i64, f16, bf16, f32, f64.
  ///
  /// The pointer check comes first: `getIntOrFloatBitWidth()` is only valid on
  /// integer/float types, so it must never be evaluated for a pointer.
  std::string getFormatSubstr(Value value) const {
    Type type = value.getType();
    if (type.isa<LLVM::LLVMPointerType>()) {
      return "%p";
    } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
      // Floats are always promoted to f64 before the call (see promoteValue).
      return "%f";
    } else if (type.isSignedInteger()) {
      // 64-bit values are stored unpromoted, so they need the `ll` length
      // modifier; narrower integers are widened to 32 bits.
      return type.getIntOrFloatBitWidth() == 64 ? "%lli" : "%i";
    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
      return type.getIntOrFloatBitWidth() == 64 ? "%llu" : "%u";
    }
    assert(false && "not supported type");
    // Reached only when assertions are compiled out; keeps the function from
    // flowing off its end without returning a value.
    return "";
  }

  /// Declares `vprintf(i8*, i8*) -> i32` as an external function at the start
  /// of the enclosing module, or returns the existing declaration.
  LLVM::LLVMFuncOp
  getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
    auto moduleOp =
        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
    StringRef funcName("vprintf");
    Operation *funcOp = moduleOp.lookupSymbol(funcName);
    if (funcOp)
      return cast<LLVM::LLVMFuncOp>(*funcOp);

    auto *context = rewriter.getContext();

    SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
                               ptr_ty(IntegerType::get(context, 8))};
    auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);

    // Restore the caller's insertion point once the declaration is created.
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());

    return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
                                             funcType);
  }

  /// Extends integers narrower than 32 bits to 32 bits and f16/bf16/f32 to
  /// f64; this comes from vprintf alignment requirements. Any other type
  /// (pointers, i32/i64, f64) is returned unchanged.
  ///
  /// Returns the (possibly new) type together with the (possibly new) value.
  std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
                                      Value value) const {
    auto *context = rewriter.getContext();
    auto type = value.getType();
    Value newOp = value;
    Type newType = type;

    bool bUnsigned = type.isUnsignedInteger();
    // Only query the bit width once the type is known to be an integer: the
    // query is invalid for pointer operands.
    if (type.isa<IntegerType>() && type.getIntOrFloatBitWidth() < 32) {
      if (bUnsigned) {
        newType = ui32_ty;
        newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
                                              value);
      } else {
        newType = i32_ty;
        newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
                                              value);
      }
    } else if (type.isBF16() || type.isF16() || type.isF32()) {
      newType = f64_ty;
      newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
                                             value);
    }

    return {newType, newOp};
  }

  /// Emits the `vprintf` call for format string `msg` and arguments `args`.
  ///
  /// `msg` gets a trailing newline and NUL and is materialised as an internal
  /// constant global named `printfFormat_<N>` (first unused N). The arguments
  /// are promoted and stored field by field into an alloca'd literal struct,
  /// whose address is passed as the second argument; with no arguments a null
  /// pointer is passed instead.
  void llPrintf(StringRef msg, ValueRange args,
                ConversionPatternRewriter &rewriter) const {
    static const char formatStringPrefix[] = "printfFormat_";
    assert(!msg.empty() && "printf with empty string not support");
    Type int8Ptr = ptr_ty(i8_ty);

    auto *context = rewriter.getContext();
    auto moduleOp =
        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
    auto funcOp = getVprintfDeclaration(rewriter);

    Value one = rewriter.create<LLVM::ConstantOp>(
        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
    Value zero = rewriter.create<LLVM::ConstantOp>(
        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));

    // Find a symbol name for the format string that is not taken yet.
    unsigned stringNumber = 0;
    SmallString<16> stringConstName;
    do {
      stringConstName.clear();
      (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
    } while (moduleOp.lookupSymbol(stringConstName));

    llvm::SmallString<64> formatString(msg);
    formatString.push_back('\n');
    formatString.push_back('\0');
    size_t formatStringSize = formatString.size_in_bytes();
    auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);

    LLVM::GlobalOp global;
    {
      ConversionPatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());
      global = rewriter.create<LLVM::GlobalOp>(
          UnknownLoc::get(context), globalType,
          /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
          rewriter.getStringAttr(formatString));
    }

    Value globalPtr =
        rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
    Value stringStart =
        rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
                                     globalPtr, mlir::ValueRange({zero, zero}));

    Value bufferPtr =
        rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);

    if (args.size() >= 1) {
      SmallVector<Value, 16> newArgs;
      SmallVector<Type> argTypes;
      for (auto arg : args) {
        Type newType;
        Value newArg;
        std::tie(newType, newArg) = promoteValue(rewriter, arg);
        argTypes.push_back(newType);
        newArgs.push_back(newArg);
      }

      Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
      auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
                                                       ptr_ty(structTy), one,
                                                       /*alignment=*/0);

      for (const auto &entry : llvm::enumerate(newArgs)) {
        auto index = rewriter.create<LLVM::ConstantOp>(
            UnknownLoc::get(context), i32_ty,
            rewriter.getI32IntegerAttr(entry.index()));
        auto fieldPtr = rewriter.create<LLVM::GEPOp>(
            UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
            allocated, ArrayRef<Value>{zero, index});
        rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
                                       fieldPtr);
      }
      bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
                                                   int8Ptr, allocated);
    }

    // Use an owning container: a ValueRange built from a braced list would
    // only view a temporary that is gone by the time the call is created.
    SmallVector<Value, 2> operands{stringStart, bufferPtr};
    rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
  }
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -4817,6 +5005,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
class ConvertTritonGPUToLLVM

View File

@@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
}
};
/// Rebuilds a `triton::PrintfOp` from its already-converted operands so the op
/// carries the operand types produced by the type converter. The prefix
/// attribute is forwarded unchanged.
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
  using OpConversionPattern<PrintfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto prefix = op.prefixAttr();
    auto convertedArgs = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, prefix, convertedArgs);
    return success();
  }
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
typeConverter, context);
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintfPattern>(typeConverter, context);
}
//

View File

@@ -533,6 +533,35 @@ public:
BlockedToMMA(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
/// Chooses how many warps are laid out along each of the two result
/// dimensions for an MMA encoding.
///
/// Starting from {1, 1}, one of the two counts is doubled per step until
/// their product reaches `numWarps`. Only `version == 2` is supported.
///
/// \param shape     result tensor shape (two dimensions are read).
/// \param version   MMA version; asserted to be 2.
/// \param numWarps  total number of warps to distribute.
/// \return          {warps along dim 0, warps along dim 1}.
static SmallVector<unsigned, 2>
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
  assert(version == 2);
  (void)version; // only read by the assert; avoids an unused warning in NDEBUG
  // TODO: Handle one warp per row for fused matmuls
  // TODO: unsigned -> int64_t to keep things uniform
  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  while (ret[0] * ret[1] < numWarps) {
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      // Grow along dim 0 while it still has tiles left to cover, otherwise
      // spill the extra warps onto dim 1.
      if (ret[0] < shape[0] / shapePerWarp[0]) {
        ret[0] *= 2;
      } else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  }
  return ret;
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
@@ -541,13 +570,20 @@ public:
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
// TODO: compute warpsPerCTA
auto newRetType = RankedTensorType::get(
oldRetType.getShape(), oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2}));
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), 2,
getWarpsPerTile(retShape, 2, numWarps)));
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
// convert output
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());

View File

@@ -197,7 +197,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
ext_mod->setTargetTriple(llvmir->getTargetTriple());
ext_mod->setDataLayout(llvmir->getDataLayout());
if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) {
if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link extern lib " << lib.first;
return nullptr;
}

View File

@@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::SelectOp>(loc, condition, trueValue,
falseValue);
})
.def("create_printf",
[](mlir::OpBuilder &self, const std::string &prefix,
const std::vector<mlir::Value> &values) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::PrintfOp>(
loc,
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
});
py::class_<mlir::PassManager>(m, "pass_manager")

View File

@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
torch_type = {
"bool": torch.bool,
'int8': torch.int8,
'uint8': torch.uint8,
'int16': torch.int16,
"int32": torch.int32,
'int64': torch.long,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
"float32": torch.float32,
"float64": torch.float64
}
def get_tensor(shape, data_type, b_positive=False):
    """Return the CUDA tensor ``[0, 1, ..., shape[0] - 1]`` with the requested dtype.

    :param shape: sequence whose first entry is the number of elements.
    :param data_type: key into ``torch_type`` (e.g. ``'int8'``, ``'float16'``).
    :param b_positive: accepted for call compatibility; it is not used.
    """
    # Integer and non-integer dtypes were built by identical code, so there is
    # no need to branch on the dtype name.
    return torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
# @pytest.mark.parametrize('data_type',
# [("int8"),
# ('int16'),
# ('int32'),
# ("int64"),
# ('float16'),
# ("float32"),
# ("float64")])
def printf(data_type):
    """Launch a one-program kernel that prints and copies 128 values of ``data_type``.

    The kernel loads ``BLOCK`` elements, passes them to ``tl.printf`` with an
    empty prefix and stores them back; the copy is then checked against the
    input with ``assert_close``.
    """
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    src = get_tensor(shape, data_type)
    dst = torch.zeros(shape, dtype=src.dtype, device="cuda")
    grid = (1,)
    kernel[grid](src, dst, BLOCK=shape[0])
    assert_close(dst, src)
printf("float16")
printf("int8")

View File

@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# # test math ops
# # ----------------
# TODO: Math module
# # @pytest.mark.parametrize("expr", [
# # 'exp', 'log', 'cos', 'sin'
# # ])
# @pytest.mark.parametrize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
# def test_math_op(expr, device='cuda'):
# _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# # ----------------
@@ -1545,43 +1540,72 @@ def test_num_warps_pow2():
# # -------------
# @pytest.mark.parametrize("dtype_str, expr, lib_path",
# [('int32', 'libdevice.ffs', ''),
# ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
# ('float64', 'libdevice.norm4d', '')])
# def test_libdevice(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):
# @triton.jit
# def kernel(X, Y, BLOCK: tl.constexpr):
# x = tl.load(X + tl.arange(0, BLOCK))
# y = GENERATE_TEST_HERE
# tl.store(Y + tl.arange(0, BLOCK), y)
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
# shape = (128, )
# rs = RandomState(17)
# # limit the range of integers so that the sum does not overflow
# x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
# if expr == 'libdevice.ffs':
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
# y_ref = np.zeros(shape, dtype=x.dtype)
# for i in range(shape[0]):
# y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
# elif expr == 'libdevice.pow':
# # numpy does not allow negative factors in power, so we use abs()
# x = np.abs(x)
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
# y_ref = np.power(x, x)
# elif expr == 'libdevice.norm4d':
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
# y_ref = np.sqrt(4 * np.power(x, 2))
if expr == 'libdevice.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
y_ref = np.zeros(shape, dtype=x.dtype)
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
elif expr == 'libdevice.pow':
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref = np.power(x, x)
elif expr == 'libdevice.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
y_ref = np.sqrt(4 * np.power(x, 2))
# x_tri = to_triton(x)
# # triton result
# y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
# kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# # compare
# if expr == 'libdevice.ffs':
# np.testing.assert_equal(y_ref, to_numpy(y_tri))
# else:
# np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
x_tri = to_triton(x)
# triton result
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
if expr == 'libdevice.ffs':
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
    """Call ``tl.libdevice.pow`` on a scalar kernel argument and compare with numpy.

    The kernel receives ``X`` as a plain Python scalar, evaluates ``pow(x, x)``
    and stores the result into every element of ``Y``.
    """
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = X
        y = GENERATE_TEST_HERE
        tl.store(Y + tl.arange(0, BLOCK), y)

    shape = (128, )
    rs = RandomState(17)
    # First draw from `rs`: the single scalar input.
    # numpy does not allow negative factors in power, so we use abs()
    x = np.abs(numpy_random((1,), dtype_str=dtype_str, rs=rs))

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})

    # Reference: the same scalar result repeated over the whole output.
    expected = np.zeros(shape, dtype=x.dtype)
    expected[:] = np.power(x, x)

    # triton result
    scalar_arg = to_triton(x)[0].item()
    # Second draw from `rs`: initial contents of the output buffer.
    out_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
    kernel[(1,)](scalar_arg, out_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})

    # compare
    np.testing.assert_allclose(expected, to_numpy(out_tri), rtol=0.01)

View File

@@ -0,0 +1,21 @@
import os
import subprocess
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
    """Run ``printf_helper.py`` in a subprocess and check its printed values.

    Every whitespace-separated token of the helper's stdout that parses as a
    number is truncated to an int; the resulting set must be exactly
    ``{0, ..., 127}``.
    """
    import sys

    # Use the interpreter that is running the tests: a bare "python" may
    # resolve to a different installation (or not exist at all).
    proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
    outs, _ = proc.communicate()
    new_lines = set()
    for token in outs.split():
        try:
            new_lines.add(int(float(token)))
        except (ValueError, OverflowError) as e:
            # Non-numeric output is reported and ignored.
            print(e)
    for i in range(128):
        assert i in new_lines
    assert len(new_lines) == 128

View File

@@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
@builtin
def printf(prefix, *args, _builder=None):
    """Emit a device-side printf of ``prefix`` followed by ``args``.

    :param prefix: a ``str`` or a ``constexpr`` wrapping one; every character
        must be in ``string.printable``.
    :param args: values to print; each is converted with ``_to_tensor``.
    """
    import string

    new_prefix = prefix.value if isinstance(prefix, constexpr) else prefix
    assert isinstance(new_prefix, str), f"{new_prefix} is not string"

    b_ascii = all(ch in string.printable for ch in new_prefix)
    assert b_ascii, f"{new_prefix} is not an ascii string"

    new_args = [_to_tensor(arg, _builder) for arg in args]
    return semantic.printf(new_prefix, new_args, _builder)

View File

@@ -56,6 +56,12 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
:return: the return value of the function
'''
dispatch_args = args.copy()
all_scalar = True
ret_shape = None
for dispatch_arg in dispatch_args:
if dispatch_arg.type.is_block():
all_scalar = False
if not all_scalar:
if len(args) == 1:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
ret_shape = dispatch_args[0].shape

View File

@@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(''), tl.void)
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
    """Create the printf op for ``prefix`` and ``args`` and wrap it as a void tensor.

    Each argument is passed to the builder through its ``handle``.
    """
    handles = [arg.handle for arg in args]
    return tl.tensor(builder.create_printf(prefix, handles), tl.void)

View File

@@ -157,15 +157,6 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@@ -318,13 +309,13 @@ else:
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
8192
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
line_vals=['cublas', 'triton'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
@@ -336,18 +327,9 @@ def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)