Merge remote-tracking branch 'origin/triton-mlir' into port-fma
.github/workflows/integration-tests.yml (vendored, 29 changed lines)
@@ -8,20 +8,19 @@ on:
       - triton-mlir

 jobs:

   Runner-Preparation:
     runs-on: ubuntu-latest
     outputs:
       matrix: ${{ steps.set-matrix.outputs.matrix }}
     steps:
       - name: Prepare runner matrix
         id: set-matrix
         run: |
           if [ x"${{ github.repository }}" == x"openai/triton" ]; then
-            echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-latest"]'
+            echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
           else
-            echo '::set-output name=matrix::["ubuntu-latest", "macos-latest"]'
+            echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
           fi

   Integration-Tests:
     needs: Runner-Preparation

@@ -33,7 +32,6 @@ jobs:
         runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}

     steps:

       - name: Checkout
         uses: actions/checkout@v2

@@ -42,33 +40,32 @@ jobs:
           rm -rf ~/.triton/cache/

       - name: Check imports
-        if: ${{ matrix.runner != 'macos-latest' }}
+        if: startsWith(matrix.runner, 'ubuntu')
         run: |
           pip install isort
           isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )

       - name: Check python style
-        if: ${{ matrix.runner != 'macos-latest' }}
+        if: startsWith(matrix.runner, 'ubuntu')
         run: |
           pip install autopep8
           autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )

       - name: Check cpp style
-        if: ${{ matrix.runner != 'macos-latest' }}
+        if: startsWith(matrix.runner, 'ubuntu')
         run: |
           pip install clang-format
           find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
           (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

       - name: Flake8
-        if: ${{ matrix.runner != 'macos-latest' }}
+        if: startsWith(matrix.runner, 'ubuntu')
         run: |
           pip install flake8
           flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )

       - name: Install Triton
         run: |
-          alias python='python3'
           cd python
           TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'

@@ -82,7 +79,7 @@ jobs:
           lit -v "$LIT_TEST_DIR"

       - name: Run python tests
-        if: ${{ matrix.runner[0] == 'self-hosted' }}
+        if: ${{matrix.runner[0] == 'self-hosted'}}
         run: |
           cd python/tests
           pytest
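Illustrative sketch (not part of the diff, and in Python rather than YAML): the Runner-Preparation job emits the runner matrix as a JSON string and the Integration-Tests job consumes it through fromJson(); for the openai/triton repository it expands to one self-hosted A10 job and one macOS job.

import json

# Hedged illustration of how the matrix string set above is consumed.
matrix = json.loads('[["self-hosted", "A10"], "macos-10.15"]')
assert matrix[0] == ["self-hosted", "A10"]
assert matrix[1] == "macos-10.15"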
@@ -30,8 +30,8 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
 // fptoui, fptosi, uitofp, sitofp,
 // extf, tructf,
 // extui, extsi, tructi
 def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
                                          SameOperandsAndResultEncoding,
                                          NoSideEffect,
                                          /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
     let summary = "Cast int64 to pointer";

@@ -43,7 +43,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
     let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }

 def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
                                          SameOperandsAndResultEncoding,
                                          NoSideEffect,
                                          /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {

@@ -57,7 +57,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
 }

 // arith.bitcast doesn't support pointers
 def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
                                      SameOperandsAndResultEncoding,
                                      NoSideEffect,
                                      /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {

@@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
     // TODO: Add verifier
 }

 def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
                                    SameOperandsAndResultEncoding,
                                    NoSideEffect,
                                    /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {

@@ -99,7 +99,7 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
 //

 def TT_AddPtrOp : TT_Op<"addptr",
                         [NoSideEffect,
                          SameOperandsAndResultShape,
                          SameOperandsAndResultEncoding,
                          TypesMatchWith<"result type matches ptr type",

@@ -224,7 +224,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
 //
 // Shape Manipulation Ops
 //
 def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
                                  SameOperandsAndResultElementType]> {
     let summary = "splat";

@@ -237,8 +237,8 @@ def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
     let hasFolder = 1;
 }

 def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
                                             DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                             SameOperandsAndResultElementType]> {
     let summary = "expand_dims";

@@ -249,7 +249,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
     let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
 }

 def TT_ViewOp : TT_Op<"view", [NoSideEffect,
                                SameOperandsAndResultElementType]> {
     let summary = "view";

@@ -261,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoSideEffect,

 }

 def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
                                          SameOperandsAndResultElementType]> {
     let summary = "broadcast. No left-padding as of now.";

@@ -274,7 +274,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
     let hasFolder = 1;
 }

 def TT_CatOp : TT_Op<"cat", [NoSideEffect,
                              SameOperandsAndResultElementType]> {
     let summary = "concatenate 2 tensors";

@@ -307,7 +307,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
 //
 // Dot Op
 //
 def TT_DotOp : TT_Op<"dot", [NoSideEffect,
                              DeclareOpInterfaceMethods<InferTypeOpInterface>,
                              TypesMatchWith<"result's type matches accumulator's type",
                                             "d", "c", "$_self">]> {

@@ -357,12 +357,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
         return $libpath/$libname:$symbol($args...)
     }];

-    let arguments = (ins Variadic<TT_Tensor>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
+    let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);

-    let results = (outs TT_Tensor:$result);
+    let results = (outs TT_Type:$result);

     let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";

 }

 //

@@ -385,4 +384,20 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
     let assemblyFormat = "attr-dict `:` type($result)";
 }

+//
+// Make PrintfOp
+//
+def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
+                  Arguments<(ins StrAttr:$prefix,
+                                 Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
+    let summary = "Device-side printf, as in CUDA for debugging";
+    let description = [{
+        `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
+        Format specifiers are generated automatically from the arguments.
+    }];
+    let assemblyFormat = [{
+        $prefix attr-dict ($args^ `:` type($args))?
+    }];
+}

 #endif // Triton_OPS
@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
                          newContiguity, newDivisibility, newConstancy);
   }
+  // TODO: All other binary ops
+  if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
+    auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
+    auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
+    auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
+      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
+    };
+    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
+                         newContiguity, newDivisibility, newConstancy);
+  }
   // Splat
   if (llvm::isa<triton::SplatOp>(op)) {
     Type _retTy = *op->result_type_begin();

@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     for (int d = 0; d < retTy.getRank(); ++d) {
       contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
       divisibility.push_back(opInfo.getDivisibility(d));
-      constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
+      constancy.push_back(opShape[d] == 1 ? retShape[d]
+                                          : opInfo.getConstancy(d));
     }
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
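Illustrative sketch (not part of the diff): the new arith.AndIOp/arith.OrIOp case collapses contiguity and divisibility to 1 and combines constancy with a gcd. A minimal Python restatement of that lattice rule, with made-up per-dimension values:

from math import gcd

# Hedged sketch of the constancy rule added above for integer and/or:
# contiguity and divisibility drop to 1; constancy is the gcd of the inputs.
def and_or_constancy(lhs_constancy, rhs_constancy):
    return [gcd(l, r) for l, r in zip(lhs_constancy, rhs_constancy)]

# Example: operands constant over runs of 8 and 12 elements along a dimension
# can only be assumed constant over runs of gcd(8, 12) = 4.
assert and_or_constancy([8], [12]) == [4]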
@@ -119,9 +119,11 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
+#define ui32_ty rewriter.getIntegerType(32, false)
 #define f16_ty rewriter.getF16Type()
 #define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
+#define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)

@@ -691,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                          TypeConverter *typeConverter,
                          ConversionPatternRewriter &rewriter, Location loc) {
   auto tensorTy = resType.cast<RankedTensorType>();
-  if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
+  if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
+      tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
     auto tensorTy = resType.cast<RankedTensorType>();
     auto srcType = typeConverter->convertType(elemType);
     auto llSrc = bitcast(constVal, srcType);

@@ -1807,6 +1810,7 @@ public:
   }
   Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
   rewriter.replaceOp(op, view);
+
   return success();
 }

@@ -4731,6 +4735,190 @@ struct FDivOpConversion
   }
 };

+struct PrintfOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    SmallVector<Value, 16> operands;
+    for (auto operand : adaptor.getOperands()) {
+      auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
+      for (auto elem : sub_operands) {
+        operands.push_back(elem);
+      }
+    }
+    std::string formatStr;
+    llvm::raw_string_ostream os(formatStr);
+    os << op.prefix();
+    if (operands.size() > 0) {
+      os << getFormatSubstr(operands[0]);
+    }
+
+    for (size_t i = 1; i < operands.size(); ++i) {
+      os << ", " << getFormatSubstr(operands[i]);
+    }
+    llPrintf(formatStr, operands, rewriter);
+    rewriter.eraseOp(op);
+    return success();
+  }
+  // get format specific for each input value
+  // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
+  std::string getFormatSubstr(Value value) const {
+    Type type = value.getType();
+    unsigned width = type.getIntOrFloatBitWidth();
+
+    if (type.isa<LLVM::LLVMPointerType>()) {
+      return "%p";
+    } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
+      return "%f";
+    } else if (type.isSignedInteger()) {
+      return "%i";
+    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
+      return "%u";
+    }
+    assert(false && "not supported type");
+  }
+
+  // declare vprintf(i8*, i8*) as external function
+  LLVM::LLVMFuncOp
+  getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
+    auto moduleOp =
+        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+    StringRef funcName("vprintf");
+    Operation *funcOp = moduleOp.lookupSymbol(funcName);
+    if (funcOp)
+      return cast<LLVM::LLVMFuncOp>(*funcOp);
+
+    auto *context = rewriter.getContext();
+
+    SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
+                               ptr_ty(IntegerType::get(context, 8))};
+    auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
+
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+    return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
+                                             funcType);
+  }
+
+  // extend integer to int32, extend float to float64
+  // this comes from vprintf alignment requirements.
+  std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
+                                      Value value) const {
+    auto *context = rewriter.getContext();
+    auto type = value.getType();
+    unsigned width = type.getIntOrFloatBitWidth();
+    Value newOp = value;
+    Type newType = type;
+
+    bool bUnsigned = type.isUnsignedInteger();
+    if (type.isIntOrIndex() && width < 32) {
+      if (bUnsigned) {
+        newType = ui32_ty;
+        newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
+                                              value);
+      } else {
+        newType = i32_ty;
+        newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
+                                              value);
+      }
+    } else if (type.isBF16() || type.isF16() || type.isF32()) {
+      newType = f64_ty;
+      newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
+                                             value);
+    }
+
+    return {newType, newOp};
+  }
+
+  void llPrintf(StringRef msg, ValueRange args,
+                ConversionPatternRewriter &rewriter) const {
+    static const char formatStringPrefix[] = "printfFormat_";
+    assert(!msg.empty() && "printf with empty string not support");
+    Type int8Ptr = ptr_ty(i8_ty);
+
+    auto *context = rewriter.getContext();
+    auto moduleOp =
+        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+    auto funcOp = getVprintfDeclaration(rewriter);
+
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
+
+    unsigned stringNumber = 0;
+    SmallString<16> stringConstName;
+    do {
+      stringConstName.clear();
+      (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
+    } while (moduleOp.lookupSymbol(stringConstName));
+
+    llvm::SmallString<64> formatString(msg);
+    formatString.push_back('\n');
+    formatString.push_back('\0');
+    size_t formatStringSize = formatString.size_in_bytes();
+    auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
+
+    LLVM::GlobalOp global;
+    {
+      ConversionPatternRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+      global = rewriter.create<LLVM::GlobalOp>(
+          UnknownLoc::get(context), globalType,
+          /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
+          rewriter.getStringAttr(formatString));
+    }
+
+    Value globalPtr =
+        rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
+    Value stringStart =
+        rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
+                                     globalPtr, mlir::ValueRange({zero, zero}));
+
+    Value bufferPtr =
+        rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
+
+    SmallVector<Value, 16> newArgs;
+    if (args.size() >= 1) {
+      SmallVector<Type> argTypes;
+      for (auto arg : args) {
+        Type newType;
+        Value newArg;
+        std::tie(newType, newArg) = promoteValue(rewriter, arg);
+        argTypes.push_back(newType);
+        newArgs.push_back(newArg);
+      }
+
+      Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
+      auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
+                                                       ptr_ty(structTy), one,
+                                                       /*alignment=*/0);
+
+      for (const auto &entry : llvm::enumerate(newArgs)) {
+        auto index = rewriter.create<LLVM::ConstantOp>(
+            UnknownLoc::get(context), i32_ty,
+            rewriter.getI32IntegerAttr(entry.index()));
+        auto fieldPtr = rewriter.create<LLVM::GEPOp>(
+            UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
+            allocated, ArrayRef<Value>{zero, index});
+        rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
+                                       fieldPtr);
+      }
+      bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
+                                                   int8Ptr, allocated);
+    }
+
+    ValueRange operands{stringStart, bufferPtr};
+    rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
+  }
+};
+
 void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns, int numWarps,
                                   AxisInfoAnalysis &axisInfoAnalysis,

@@ -4817,6 +5005,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
                                                            benefit);
   patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
+  patterns.add<PrintfOpConversion>(typeConverter, benefit);
 }

 class ConvertTritonGPUToLLVM
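Illustrative sketch (not part of the diff): the conversion above builds a format string from the op's prefix plus one automatically chosen specifier per element, promotes each argument to the widths CUDA's vprintf expects, packs the promoted values into a stack struct, and calls vprintf(format, buffer). A minimal Python restatement of the specifier and promotion rules it implements (function names are made up):

# Hedged sketch of the rules used by PrintfOpConversion above.
def specifier(kind, signed=True):
    # kind is one of "pointer", "float", "int"
    if kind == "pointer":
        return "%p"
    if kind == "float":
        return "%f"
    return "%i" if signed else "%u"

def promoted_bits(kind, bits):
    if kind == "int" and bits < 32:
        return 32      # ZExt / SExt to 32 bits for vprintf alignment
    if kind == "float" and bits < 64:
        return 64      # FPExt f16 / bf16 / f32 to double
    return bits

# e.g. an i8 argument is printed with "%i" after sign-extension to i32,
# and an f16 argument with "%f" after extension to f64.
assert (specifier("int"), promoted_bits("int", 8)) == ("%i", 32)
assert (specifier("float"), promoted_bits("float", 16)) == ("%f", 64)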
@@ -339,6 +339,18 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   }
 };

+struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
+  using OpConversionPattern<PrintfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
+                                                  adaptor.getOperands());
+    return success();
+  }
+};
+
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();

@@ -350,8 +362,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
       TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
       TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
-      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
-      typeConverter, context);
+      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
+      TritonPrintfPattern>(typeConverter, context);
 }

 //
@@ -533,6 +533,35 @@ public:
   BlockedToMMA(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}

+  static SmallVector<unsigned, 2>
+  getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
+    assert(version == 2);
+    // TODO: Handle one warp per row for fused matmuls
+    // TODO: unsigned -> int64_t to keep things uniform
+    SmallVector<unsigned, 2> ret = {1, 1};
+    SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+    bool changed = false;
+    // TODO (@daadaada): double-check.
+    // original logic in
+    // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
+    // seems buggy for shape = [32, 16] ?
+    do {
+      changed = false;
+      if (ret[0] * ret[1] >= numWarps)
+        break;
+      if (shape[0] / shapePerWarp[0] / ret[0] >=
+          shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
+        if (ret[0] < shape[0] / shapePerWarp[0]) {
+          ret[0] *= 2;
+        } else
+          ret[1] *= 2;
+      } else {
+        ret[1] *= 2;
+      }
+    } while (true);
+    return ret;
+  }
+
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {

@@ -541,13 +570,20 @@ public:
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
-    // TODO: compute warpsPerCTA
-    auto newRetType = RankedTensorType::get(
-        oldRetType.getShape(), oldRetType.getElementType(),
-        triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2}));
+    // get MMA encoding for the given number of warps
+    auto retShape = oldRetType.getShape();
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    auto newRetType =
+        RankedTensorType::get(retShape, oldRetType.getElementType(),
+                              triton::gpu::MmaEncodingAttr::get(
+                                  oldRetType.getContext(), 2,
+                                  getWarpsPerTile(retShape, 2, numWarps)));
+    // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
     auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
         oldAcc.getLoc(), newRetType, oldAcc);
+    // convert output
     auto newDot = rewriter.create<triton::DotOp>(
         dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
         newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
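Illustrative sketch (not part of the diff): a hedged Python restatement of the warp-tiling loop added above for version == 2 (shapePerWarp = [16, 8]), with one worked example; the shape and warp count are made up.

# Hedged sketch of getWarpsPerTile for MMA version 2.
def warps_per_tile(shape, num_warps):
    ret = [1, 1]
    while ret[0] * ret[1] < num_warps:
        if shape[0] // 16 // ret[0] >= shape[1] // (8 * 2) // ret[1]:
            if ret[0] < shape[0] // 16:
                ret[0] *= 2
            else:
                ret[1] *= 2
        else:
            ret[1] *= 2
    return ret

# Example: a 128x128 dot result with 8 warps splits as 4 warps along M and
# 2 along N, so each warp covers a 32x64 sub-tile of the accumulator.
assert warps_per_tile([128, 128], 8) == [4, 2]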
@@ -197,7 +197,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
     ext_mod->setTargetTriple(llvmir->getTargetTriple());
     ext_mod->setDataLayout(llvmir->getDataLayout());

-    if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) {
+    if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod),
+                                  llvm::Linker::Flags::LinkOnlyNeeded)) {
       llvm::errs() << "Failed to link extern lib " << lib.first;
       return nullptr;
     }
@@ -1185,6 +1185,16 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::SelectOp>(loc, condition, trueValue,
                                                falseValue);
+          })
+      .def("create_printf",
+           [](mlir::OpBuilder &self, const std::string &prefix,
+              const std::vector<mlir::Value> &values) -> void {
+             auto loc = self.getUnknownLoc();
+             self.create<mlir::triton::PrintfOp>(
+                 loc,
+                 mlir::StringAttr::get(self.getContext(),
+                                       llvm::StringRef(prefix)),
+                 values);
           });

   py::class_<mlir::PassManager>(m, "pass_manager")
python/tests/printf_helper.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+import torch
+from torch.testing import assert_close
+
+import triton
+import triton.language as tl
+
+torch_type = {
+    "bool": torch.bool,
+    'int8': torch.int8,
+    'uint8': torch.uint8,
+    'int16': torch.int16,
+    "int32": torch.int32,
+    'int64': torch.long,
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16,
+    "float32": torch.float32,
+    "float64": torch.float64
+}
+
+
+def get_tensor(shape, data_type, b_positive=False):
+    x = None
+    if data_type.startswith('int'):
+        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
+    else:
+        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
+
+    return x
+
+# @pytest.mark.parametrize('data_type',
+#                          [("int8"),
+#                           ('int16'),
+#                           ('int32'),
+#                           ("int64"),
+#                           ('float16'),
+#                           ("float32"),
+#                           ("float64")])
+
+
+def printf(data_type):
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.printf("", x)
+        tl.store(Y + tl.arange(0, BLOCK), x)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x = get_tensor(shape, data_type)
+    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0])
+    assert_close(y, x)
+
+
+printf("float16")
+printf("int8")
@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     # triton result
     x_tri = to_triton(x, device=device, dst_type=dtype_x)
     z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
-    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
+    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
     # compare
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
 # # test math ops
 # # ----------------

-# TODO: Math module
-# # @pytest.mark.parametrize("expr", [
-# #     'exp', 'log', 'cos', 'sin'
-# # ])

-# @pytest.mark.parametrize("expr", [
-#     'exp', 'log', 'cos', 'sin'
-# ])
-# def test_math_op(expr, device='cuda'):
-#     _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
+@pytest.mark.parametrize("expr", [
+    'exp', 'log', 'cos', 'sin'
+])
+def test_math_op(expr, device='cuda'):
+    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)


 # # ----------------
@@ -1545,43 +1540,72 @@ def test_num_warps_pow2():
 # # -------------


-# @pytest.mark.parametrize("dtype_str, expr, lib_path",
-#                          [('int32', 'libdevice.ffs', ''),
-#                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
-#                           ('float64', 'libdevice.norm4d', '')])
-# def test_libdevice(dtype_str, expr, lib_path):
-
-#     @triton.jit
-#     def kernel(X, Y, BLOCK: tl.constexpr):
-#         x = tl.load(X + tl.arange(0, BLOCK))
-#         y = GENERATE_TEST_HERE
-#         tl.store(Y + tl.arange(0, BLOCK), y)
-
-#     shape = (128, )
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
-
-#     if expr == 'libdevice.ffs':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
-#         y_ref = np.zeros(shape, dtype=x.dtype)
-#         for i in range(shape[0]):
-#             y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
-#     elif expr == 'libdevice.pow':
-#         # numpy does not allow negative factors in power, so we use abs()
-#         x = np.abs(x)
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
-#         y_ref = np.power(x, x)
-#     elif expr == 'libdevice.norm4d':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
-#         y_ref = np.sqrt(4 * np.power(x, 2))
-
-#     x_tri = to_triton(x)
-#     # triton result
-#     y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
-#     kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
-#     # compare
-#     if expr == 'libdevice.ffs':
-#         np.testing.assert_equal(y_ref, to_numpy(y_tri))
-#     else:
-#         np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('int32', 'libdevice.ffs', ''),
+                          ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
+                          ('float64', 'libdevice.norm4d', '')])
+def test_libdevice_tensor(dtype_str, expr, lib_path):
+
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)
+
+    shape = (128, )
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+
+    if expr == 'libdevice.ffs':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
+        y_ref = np.zeros(shape, dtype=x.dtype)
+        for i in range(shape[0]):
+            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
+    elif expr == 'libdevice.pow':
+        # numpy does not allow negative factors in power, so we use abs()
+        x = np.abs(x)
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+        y_ref = np.power(x, x)
+    elif expr == 'libdevice.norm4d':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
+        y_ref = np.sqrt(4 * np.power(x, 2))
+
+    x_tri = to_triton(x)
+    # triton result
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    if expr == 'libdevice.ffs':
+        np.testing.assert_equal(y_ref, to_numpy(y_tri))
+    else:
+        np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+
+
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('float32', 'libdevice.pow', '')])
+def test_libdevice_scalar(dtype_str, expr, lib_path):
+
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = X
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)
+
+    shape = (128, )
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
+    y_ref = np.zeros(shape, dtype=x.dtype)
+
+    # numpy does not allow negative factors in power, so we use abs()
+    x = np.abs(x)
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+    y_ref[:] = np.power(x, x)
+
+    # triton result
+    x_tri = to_triton(x)[0].item()
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
python/tests/test_printf.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import os
+import subprocess
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+printf_path = os.path.join(dir_path, "printf_helper.py")
+
+
+def test_printf():
+    proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
+    (outs, err) = proc.communicate()
+    outs = outs.split()
+    new_lines = set()
+    for line in outs:
+        try:
+            value = int(float(line))
+            new_lines.add(value)
+        except Exception as e:
+            print(e)
+    for i in range(128):
+        assert i in new_lines
+    assert len(new_lines) == 128
@@ -1197,3 +1197,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
 @triton.jit
 def zeros_like(input):
     return zeros(input.shape, input.dtype)
+
+
+@builtin
+def printf(prefix, *args, _builder=None):
+    import string
+    new_prefix = prefix
+    if isinstance(prefix, constexpr):
+        new_prefix = prefix.value
+    assert isinstance(new_prefix, str), f"{new_prefix} is not string"
+    b_ascii = True
+    for ch in new_prefix:
+        if ch not in string.printable:
+            b_ascii = False
+            break
+    assert b_ascii, f"{new_prefix} is not an ascii string"
+    new_args = []
+    for arg in args:
+        new_args.append(_to_tensor(arg, _builder))
+    return semantic.printf(new_prefix, new_args, _builder)
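Illustrative sketch (not part of the diff): a minimal usage example of the new builtin inside a jitted kernel. The kernel and argument names are made up, and the exact console output is chosen by the lowering (prefix followed by comma-separated values), so this is only indicative.

import triton
import triton.language as tl

@triton.jit
def debug_kernel(X, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    x = tl.load(X + tl.arange(0, BLOCK))
    # The prefix must be a plain printable string; every extra argument gets
    # an automatically chosen format specifier when lowered to vprintf.
    tl.printf("pid, x: ", pid, x)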
@@ -56,28 +56,34 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
     :return: the return value of the function
     '''
     dispatch_args = args.copy()
-    if len(args) == 1:
-        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
-        ret_shape = dispatch_args[0].shape
-    elif len(args) == 2:
-        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
-        dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
-        dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
-            dispatch_args[0], dispatch_args[1], _builder)
-        ret_shape = dispatch_args[0].shape
-    else:
-        for i in range(len(dispatch_args)):
-            dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
-        broadcast_arg = dispatch_args[0]
-        # Get the broadcast shape over all the arguments
-        for i in range(len(dispatch_args)):
-            _, broadcast_arg = semantic.binary_op_type_checking_impl(
-                dispatch_args[i], broadcast_arg, _builder)
-        # Change the shape of each argument based on the broadcast shape
-        for i in range(len(dispatch_args)):
-            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
-                dispatch_args[i], broadcast_arg, _builder)
-        ret_shape = broadcast_arg.shape
+    all_scalar = True
+    ret_shape = None
+    for dispatch_arg in dispatch_args:
+        if dispatch_arg.type.is_block():
+            all_scalar = False
+    if not all_scalar:
+        if len(args) == 1:
+            dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
+            ret_shape = dispatch_args[0].shape
+        elif len(args) == 2:
+            dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
+            dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
+            dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
+                dispatch_args[0], dispatch_args[1], _builder)
+            ret_shape = dispatch_args[0].shape
+        else:
+            for i in range(len(dispatch_args)):
+                dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
+            broadcast_arg = dispatch_args[0]
+            # Get the broadcast shape over all the arguments
+            for i in range(len(dispatch_args)):
+                _, broadcast_arg = semantic.binary_op_type_checking_impl(
+                    dispatch_args[i], broadcast_arg, _builder)
+            # Change the shape of each argument based on the broadcast shape
+            for i in range(len(dispatch_args)):
+                dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
+                    dispatch_args[i], broadcast_arg, _builder)
+            ret_shape = broadcast_arg.shape
     func = getattr(_builder, "create_external_elementwise")
     return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
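Illustrative sketch (not part of the diff): with this change, calls where no argument is a block tensor skip the broadcasting path and leave ret_shape as None, which is what lets the new test_libdevice_scalar case dispatch libdevice.pow on plain scalars. A hedged restatement of that decision, with a stand-in for the block check:

# Hedged sketch of the new all-scalar short-circuit in elementwise().
def needs_broadcast(arg_is_block_flags):
    # Broadcasting and shape inference only run when at least one argument
    # is a block (tensor) value; otherwise ret_shape stays None and the
    # extern call is emitted as a scalar operation.
    return any(arg_is_block_flags)

assert needs_broadcast([False, False]) is False   # e.g. pow(2.0, 3.0)
assert needs_broadcast([True, False]) is True     # e.g. pow(x_tensor, 3.0)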
@@ -1123,3 +1123,10 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:

 def debug_barrier(builder: ir.builder) -> tl.tensor:
     return tl.tensor(builder.create_barrier(''), tl.void)
+
+
+def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
+    new_args = []
+    for arg in args:
+        new_args.append(arg.handle)
+    return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
@@ -157,15 +157,6 @@ import triton.language as tl
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
     ],
     key=['M', 'N', 'K'],
 )
@@ -318,13 +309,13 @@ else:
     triton.testing.Benchmark(
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
         x_vals=[
-            128 * i for i in range(2, 33)
+            8192
         ],  # different possible values for `x_name`
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
+        line_vals=['cublas', 'triton'],
         # label name for the lines
-        line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
+        line_names=["cuBLAS", "Triton"],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel="TFLOPS",  # label name for the y-axis
@@ -336,18 +327,9 @@ def benchmark(M, N, K, provider):
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
     if provider == 'cublas':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
     if provider == 'triton':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
-    if provider == 'cublas + relu':
-        torch_relu = torch.nn.ReLU(inplace=True)
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: torch_relu(torch.matmul(a, b))
-        )
-    if provider == 'triton + relu':
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: matmul(a, b, activation=leaky_relu)
-        )
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)