37 Commits

Author SHA1 Message Date
Jokeren
43408fef5a Fix 2022-12-06 17:09:09 -08:00
Jokeren
e817fdf1b9 Fix 2022-12-06 13:46:21 -08:00
Jokeren
8dd099beef Fix 2022-12-06 13:31:42 -08:00
Jokeren
f20f48a255 Move 2022-12-06 13:29:29 -08:00
Jokeren
3eff110fbc Restore 2022-12-06 13:28:43 -08:00
Jokeren
5f85b79718 Merge branch 'triton-mlir' into keren/insert-slice-other-nonzero 2022-12-06 13:25:20 -08:00
Jokeren
bab7338965 Fix 2022-12-06 13:24:50 -08:00
Jokeren
74f3d7a80f Fix 2022-12-06 12:53:25 -08:00
Philippe Tillet
115cd3ac47 [FRONTEND] Added reshape as an alias for view (for now) (#956) 2022-12-06 09:57:05 -08:00
Philippe Tillet
532e10cf87 [FRONTEND][BACKEND] Clean-up transpositions (#953) 2022-12-06 09:32:13 -08:00
Keren Zhou
16e973edf2 [BACKEND] Fix dependency analysis in pipeline (#946) 2022-12-06 09:08:55 -08:00
Jokeren
b539e031e8 Add test 2022-12-05 23:38:54 -08:00
Jokeren
46fa29496c Init 2022-12-05 23:18:13 -08:00
Crutcher Dunnavant
9490252261 [FRONTEND] Support alternative install locations of system libdevice.10.bc (#951) 2022-12-06 03:41:44 +00:00
Yan Chunwei
e419781978 [Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)
TODO:

- Add more cases
- Currently, we just set vec to 4 to make the basic cases pass

Issue:

- the vec in the shared layout is different from the one on the master branch
- when vec=1, it encounters a CUDA misalignment error; it doesn't work on the
master branch either
- when vec is set to the same value as the master branch, the MMA works
2022-12-06 10:57:08 +08:00
Crutcher Dunnavant
189491727a [FRONTEND] Extract and unify @builtin/@extern (#913)
This change attaches builtin-ness as an explicit attribute, rather than
a module-prefix expectation. This permits us to source those builtins
from multiple sub-modules (useful when some builtins are part of the
true cyclic implementation core, and some are just useful library
additions), and it also prevents accidental inclusion of non-builtins that
happen to be in the right library.

Once the flag exists and the compiler uses `is_builtin()` for
decision making, the existing `@extern` interface becomes isomorphic to
`@builtin`, and the two interfaces can be unified.

Leaving `@extern` as a thin wrapper, and encouraging its continued use,
keeps the door open for adding extern tracing, metric hooks, or scanning
later on.

* Add a `triton.impl` package to hold the core, order-dependent impl details.
* Extract `@builtin` and unify `@extern`; add `is_builtin()`
  * Add a sense bit so that `@builtin` detection is less fragile.
* Modify the compiler to use `is_builtin()`
2022-12-05 22:59:41 +00:00
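The Python diffs for this change are suppressed further down the page, so here is a minimal sketch of the attribute-based approach the message describes, assuming a hypothetical `__triton_builtin__` attribute name (the real attribute name and `triton.impl` layout may differ):
```
import functools


def builtin(fn):
    """Mark fn as a Triton-style builtin via an explicit attribute."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    # Hypothetical attribute name -- the "sense bit" mentioned in the bullets.
    wrapper.__triton_builtin__ = True
    return wrapper


def is_builtin(fn) -> bool:
    # Builtin-ness is now an attribute check, not a module-prefix check.
    return getattr(fn, "__triton_builtin__", False)


@builtin
def max(input, axis, _builder=None):  # mirrors the signature shown in the next commit
    ...


assert is_builtin(max)
```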
Crutcher Dunnavant
e0072d210a [FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)
Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
2022-12-05 22:41:02 +00:00
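The revealed type above is still `Any`-typed, so the goal here is signature transparency rather than full generics. One common way to keep a decorator from erasing types is a `ParamSpec`-based wrapper; this is only an illustrative sketch (Python 3.10+), not the code from the PR, whose diff is suppressed on this page:
```
from typing import Callable, ParamSpec, TypeVar  # use typing_extensions on Python < 3.10

P = ParamSpec("P")
R = TypeVar("R")


def jit(fn: Callable[P, R]) -> Callable[P, R]:
    # Annotating the return as Callable[P, R] is what keeps mypy from
    # collapsing the decorated function into an opaque Any.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return fn(*args, **kwargs)
    return wrapper


@jit
def add(x: int, y: int) -> int:
    return x + y


print(add(1, 2))  # 3
# With mypy: reveal_type(add) -> "def (x: int, y: int) -> int"
```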
Crutcher Dunnavant
2fa17588f7 [FRONTEND] Expand __init__ * imports, add __all__ (#912)
Expand `from .foo import *` to full listings and `__all__` sections.

This reifies the module export listings, which is useful for code
importing this module; without this, clients will need special `mypy`
control pragmas for this library.

This removes a number of `# flake8` control pragmas.

Verified with `flake8`
2022-12-05 14:22:55 -08:00
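For readers unfamiliar with the pattern, here is an illustrative before/after using a stdlib module as a stand-in; the actual change edits Triton's `__init__` files, whose diffs are suppressed on this page:
```
# Before: implicit re-exports, typically silenced with pragmas such as
#   from math import *  # noqa: F401,F403

# After: explicit imports plus an __all__ listing reify the exports.
from math import pi, sqrt

__all__ = [
    "pi",
    "sqrt",
]

print(sqrt(pi))
```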
goostavz
e057c65cf0 [BACKEND] Porting the legacy heuristic rule in assigning shared layout for A/B of MMAv1 (#948) 2022-12-05 11:30:23 -08:00
Philippe Tillet
99c7e0e008 [BUILD] Change default build type (#945) 2022-12-03 17:47:33 -08:00
Keren Zhou
f2fcaeabf3 [BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937) 2022-12-03 19:14:12 +00:00
Philippe Tillet
8edfe813a5 [FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943) 2022-12-03 09:58:24 -08:00
goostavz
4d64589b22 [Triton-MLIR][Backend] Fix the definition of MmaEncodingAttr v1, and the output sequence of DotConversion in MMAv1 (#941) 2022-12-03 21:12:48 +08:00
donproc
521ff9ad74 [TRITON-MLIR][FRONTEND] fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-02 17:45:29 +08:00
Keren Zhou
c280ebda1b [Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)
1. Add missing barriers and revert the previous temporary solution
2. Extract the `run` method from the membar analysis (sketched below), because
the analysis should have two phases: construction, which doesn't modify any IR,
and modification, which inserts barrier IRs. Hopefully this makes the use of
membar clearer.
2022-12-01 11:54:18 -08:00
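The two-phase split described in point 2 above is the key design change. The sketch below illustrates the pattern in Python only; the real `MembarAnalysis` is the C++ class shown in the `Membar.h` diff further down, and the op names here are made up:
```
# Conceptual sketch of the construction/modification split; not the actual API.
class MembarSketch:
    def __init__(self, allocation):
        # Phase 1: construction. Nothing is modified here.
        self.allocation = allocation

    def run(self, ops):
        # Phase 2: modification. Insert a barrier after each async_wait,
        # mirroring the rule visible in the Membar.cpp diff further down.
        out = []
        for op in ops:
            out.append(op)
            if op == "async_wait":
                out.append("barrier")
        return out


analysis = MembarSketch(allocation=None)   # no IR touched yet
print(analysis.run(["load", "async_wait", "dot"]))
# ['load', 'async_wait', 'barrier', 'dot']
```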
donproc
9def1bcebf [TRITON-MLIR][FRONTEND] minor fix to run through atomic_cas test (#925)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-01 13:43:26 +00:00
Keren Zhou
7d90a07d0b [Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)
1. Improve the pipeline's comment
2. Decompose insert_slice_async when the load vector size is not supported
3. Add a test that could make our gemm code fail

Copy my comments here:

There's a knob that may cause a performance regression when decomposition
has been performed. We should remove this knob once we have a thorough
analysis of async wait. Currently, we decompose `insert_slice_async`
into `load` and `insert_slice` without knowing which `async_wait` is
responsible for the `insert_slice_async`. To guarantee correctness, we
blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed (as sketched below).

There are two options to improve this:
1. We can perform a dataflow analysis to find the `async_wait` that is
responsible for the `insert_slice_async` in the backend.
2. We can modify the pipeline to perform the decomposition before the
`async_wait` is inserted. However, it is also risky because we don't
know the correct vectorized shape yet in the pipeline pass. Making the
pipeline pass aware of the vectorization could introduce additional
dependencies on the AxisInfoAnalysis and the Coalesce analysis.
2022-11-30 10:07:34 -08:00
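The conservative rule described above can be sketched as follows. This is illustrative Python, not the actual C++ pass; the op records and the `async_wait` `num` field are assumptions, while the eligible byte widths mirror `getEligibleLoadByteWidth` for sm_80+ shown later in this diff:
```
ELIGIBLE_LOAD_BYTES = {4, 8, 16}  # from getEligibleLoadByteWidth(>= 80)


def decompose_insert_slice_async(ops):
    decomposed_any = False
    out = []
    for op in ops:
        if op["name"] == "insert_slice_async" and op["bytes"] not in ELIGIBLE_LOAD_BYTES:
            # Unsupported load vector width: lower into a plain load + insert_slice.
            out += [{"name": "load"}, {"name": "insert_slice"}]
            decomposed_any = True
        else:
            out.append(op)
    if decomposed_any:
        # We no longer know which async_wait guards which insert_slice_async,
        # so conservatively make every async_wait wait for all async ops.
        for op in out:
            if op["name"] == "async_wait":
                op["num"] = 0
    return out


ops = [{"name": "insert_slice_async", "bytes": 2}, {"name": "async_wait", "num": 2}]
print(decompose_insert_slice_async(ops))
# [{'name': 'load'}, {'name': 'insert_slice'}, {'name': 'async_wait', 'num': 0}]
```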
Philippe Tillet
6461254fb5 [BACKEND] Make flash attention forward pass work (#928)
This also simplifies BroadcastOp codegen
2022-11-30 10:13:24 +00:00
goostavz
4e6a8209ed [Triton-MLIR] Two fixes on allocation and backend related with MMA v1 (#930) 2022-11-30 09:27:26 +00:00
Philippe Tillet
9bb54402b3 [FRONTEND][BACKEND] Small fixes to multiple_of, num_programs, axisinfo; enable block-sparse tests (#927) 2022-11-29 20:00:34 +01:00
Philippe Tillet
66c36c4378 [BACKEND] Fixed bounds-wrapping issues (#926)
This fixes an issue that led to out-of-bounds shared memory accesses on
small matrices
2022-11-29 17:56:45 +01:00
Qingyi Liu
661be523c0 [Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924) 2022-11-29 11:50:31 +08:00
Yan Chunwei
c87fbf886e [Triton-MLIR][BACKEND] Remove static and unnamed namespace in Utility.h (#923)
Reference
https://wiki.sei.cmu.edu/confluence/display/cplusplus/DCL59-CPP.+Do+not+define+an+unnamed+namespace+in+a+header+file
2022-11-29 01:06:06 +00:00
goostavz
0c1d4d764e [Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)
The e2e verification of mma v1 is not done yet.
We are getting this merged in advance just to prevent more conflicts.
2022-11-28 08:10:30 +00:00
Qingyi Liu
9d31998a9d [Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918) 2022-11-27 22:59:27 -08:00
Yan Chunwei
04ec5deb41 [Triton-MLIR][BACKEND] decouple the dot code (#921)
This PR
- applies a minimal modification to decouple the Dot helper code from
TritonGPUToLLVM.cpp into a separate local header file, to make it easier to
share some data structures for Dot
- adds the patches necessary for transA and transB
- adds the patches necessary for MMA v1 execution in the backend
2022-11-28 13:30:27 +08:00
goostavz
630dc315ee [Triton-MLIR] uncomment the UT in test_gemm that has already been fixed (#920) 2022-11-28 11:23:20 +08:00
53 changed files with 4825 additions and 3091 deletions

View File

@@ -17,7 +17,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
@@ -79,11 +79,18 @@ jobs:
lit -v "$LIT_TEST_DIR"
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
run: |
cd python/tests
pytest
# TODO[Superjomn] Enable all the tests on V100 if available
- name: Run python tests on V100
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
cd python/tests
pytest test_gemm.py::test_gemm_no_scf_for_mmav1
- name: Run CXX unittests
run: |
cd python/

View File

@@ -20,8 +20,6 @@ SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

View File

@@ -131,6 +131,12 @@ public:
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
unsigned getPtrVectorSize(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
};
} // namespace mlir

View File

@@ -29,7 +29,11 @@ public:
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
private:
struct RegionInfo {
@@ -82,10 +86,6 @@ private:
}
};
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:

View File

@@ -26,6 +26,12 @@ public:
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
unsigned getScratchSizeInBytes();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

View File

@@ -295,6 +295,18 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
@@ -327,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
@@ -351,6 +363,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}
//

View File

@@ -25,11 +25,13 @@ namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);

View File

@@ -87,17 +87,25 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin version 1 ----
// TODO: handle rep (see
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
if (version == 1) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later.
// Currently, we only support ld.v2, because the mma layout varies with different ld vector widths.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
return $_get(context, 1, perPhase, maxPhase, order);
}
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin version 2 ----
if (version == 2) {
@@ -106,14 +114,14 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
}
// --- handle B operand ---
if (opIdx == 1) {
@@ -121,8 +129,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
}
llvm_unreachable("invalid operand index");
}
@@ -293,7 +301,7 @@ partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Information about this layout can be found in the official PTX documentation
Note: the layout is different from the one recommended in the PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
@@ -301,29 +309,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
[ ..............................................................
[ ..............................................................
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
warp 1 = warp0 + 32
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
[ ..............................................................
[ ..............................................................
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
[ ............................................................... ]
// -------------------------------- version = 2 --------------------------- //

View File

@@ -151,6 +151,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
// attr-dict `:` type($src) `->` type($dst)
//}];
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];

View File

@@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
if (isa<tensor::ExtractSliceOp>(op)) {
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;

View File

@@ -13,6 +13,7 @@
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
@@ -60,8 +61,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -88,25 +89,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return paddedRepShape;
}
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
ReduceOpHelper helper(op);
SmallVector<unsigned> smemShape;
auto srcShape = helper.getSrcShape();
for (auto d : srcShape)
smemShape.push_back(d);
auto axis = op.axis();
if (helper.isFastReduction()) {
smemShape[axis] = helper.getInterWarpSize();
} else {
smemShape[axis] =
std::min(smemShape[axis], helper.getThreadsReductionAxis());
}
return smemShape;
}
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
@@ -173,21 +155,9 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
// TODO(Keren): Reduce with index is not supported yet.
auto value = op->getOperand(0);
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
auto smemShape = getScratchConfigForReduce(reduceOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
if (fastReduce) {
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
elems = std::max<unsigned>(elems, numWarps * 32);
}
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();

View File

@@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
}
}
// TODO: refactor & complete binary ops
// Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
@@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
@@ -261,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation(
return result;
}
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
// Here the order is sorted by contiguity, so the first element
// should have the largest contiguity.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto axisInfo = lookupLatticeElement(ptr)->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
unsigned maxContig = axisInfo.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskAxis = lookupLatticeElement(mask)->getValue();
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
return alignment;
}
} // namespace mlir

View File

@@ -24,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation,
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
}
curRegionInfo.join(regionInfo);
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may share the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
}
}
@@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// Do not insert barriers before control flow operations and
// alloc/extract/insert
// alloc is an allocation op without memory write.
// In contrast, arith.constant is an allocation op with memory write.
// FIXME(Keren): extract is always alias for now
// FIXME(Keren): extract_slice is always alias for now
return;
}
@@ -60,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
return;
}
if (isa<triton::gpu::AsyncWaitOp>(op)) {
// If the current op is an async wait, we insert a barrier op and sync
// previous reads and writes.
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier, we
// insert a barrier op and sync previous reads and writes.
regionInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());

View File

@@ -37,6 +37,50 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);
return bytes;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -61,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) {
}
bool maybeAliasOp(Operation *op) {
return isa<tensor::ExtractSliceOp>(op) ||
return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -0,0 +1,319 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
#define insert_element(...) \
rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
#define extract_element(...) \
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
// Creator for constant
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
namespace mlir {
namespace LLVM {
using namespace mlir::triton;
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
}
return results;
}
// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
}
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
auto type = type::f64Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF64FloatAttr(v));
}
// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
/// Helper function to get strides from a given shape and its order
SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
Location loc, ConversionPatternRewriter &rewriter) {
auto rank = shape.size();
SmallVector<Value> strides(rank);
auto stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
}
return strides;
}
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at arbitrary offsets.
// Take $a[16:32, 16:32] as an example: though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations but not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
// objects from the originally allocated object.
SharedMemoryObject(Value base, ArrayRef<Value> strides,
ArrayRef<Value> offsets)
: base(base), strides(strides.begin(), strides.end()),
offsets(offsets.begin(), offsets.end()) {}
SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
ArrayRef<unsigned> order, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
for (auto idx : order) {
offsets.emplace_back(i32_val(0));
}
}
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
elems.append(offsets.begin(), offsets.end());
return elems;
}
SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
return types;
}
Value getCSwizzleOffset(int order) const {
assert(order >= 0 && order < strides.size());
return offsets[order];
}
Value getBaseBeforeSwizzle(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, base, offset);
}
};
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
PTXBuilder builder;
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
}
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i) {
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(loc, rewriter, val0, i);
val1 = shflSync(loc, rewriter, val1, i);
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(vec, val.getType());
}
PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
auto *dOpr = builder.newOperand("=r");
auto *aOpr = builder.newOperand(val, "r");
auto *bOpr = builder.newConstantOperand(i);
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
return builder.launch(rewriter, loc, val.getType(), false);
}
} // namespace LLVM
} // namespace mlir
#endif

View File

@@ -245,9 +245,53 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
adaptor.allowTF32());
return success();
}
};
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
Attribute srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
// TODO: end-to-end correctness is broken if
// the input is blocked and the output is shared
// with different order. Maybe a backend issue in BlockedToShared?
SmallVector<unsigned> order = {1, 0};
if (auto srcBlockedEncoding =
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
srcEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
srcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), srcEncoding);
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
src);
}
auto srcSharedEncoding =
srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
srcSharedEncoding.getOrder().end());
SmallVector<int64_t> retShapes(srcType.getShape().begin(),
srcType.getShape().end());
std::reverse(retOrder.begin(), retOrder.end());
std::reverse(retShapes.begin(), retShapes.end());
auto retEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
auto retType =
RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
return success();
}
};
@@ -286,8 +330,8 @@ struct TritonAtomicCASPattern
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()),
adaptor.ptr(), adaptor.cmp(), adaptor.val());
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.cmp(), adaptor.val());
return success();
}
};
@@ -390,9 +434,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
typeConverter, context);
}
//
@@ -456,10 +501,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
// This is borrowed from ConvertFIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
// We need to implement something more sophisticated here that tracks which
// types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
// another to 2 types, so newResultTypes would at least be the right size to
// not crash in the llvm::zip call below, but then we would set the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type> newResultTypes;
for (auto type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
// See comments in the ForOp pattern for why we clone without regions and
// then inline.
scf::IfOp newOp =
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
// Update the operands and types.
newOp->setOperands(adaptor.getOperands());
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
context);
}
class ConvertTritonToTritonGPU

View File

@@ -240,12 +240,17 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp =
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
inferredReturnTypes.push_back(argEltTy);
inferredReturnTypes.push_back(retEltTy);
} else {
// nd-tensor where n >= 1
// infer encoding
@@ -264,11 +269,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
}
// create type
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return mlir::success();
}
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
}
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();

View File

@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;

View File

@@ -71,22 +71,22 @@ unsigned getElemsPerThread(Type type) {
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
}
SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return SmallVector<unsigned>{4, 8};
return {4, 8};
if (mmaLayout.getVersion() == 2)
return SmallVector<unsigned>{8, 4};
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}
SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
blockedLayout.getWarpsPerCTA().end());
@@ -99,16 +99,22 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return {};
}
SmallVector<unsigned> getSizePerThread(Attribute layout) {
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return SmallVector<unsigned>{2, 2};
if (mmaLayout.getVersion() == 2) {
return {2, 2};
} else if (mmaLayout.getVersion() == 1) {
// Note: here the definition of sizePerThread is obscure; it doesn't
// mean that vecSize=4 can be supported in the last dimension.
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -136,6 +142,15 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}
SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
return {1, 2};
} else {
return getSizePerThread(layout);
}
}
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -194,6 +209,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.getVersion() == 1) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
@@ -205,9 +230,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
@@ -358,6 +383,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
}
return res;
@@ -632,6 +659,15 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
DenseSet<unsigned>
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//

View File

@@ -50,10 +50,25 @@ public:
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op->getContext(), 1, 1, 1, {1, 0}));
auto dstDotOperand =
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParent = dstDotOperand.getParent();
if (dstDotOperand.getOpIdx() == 1 ||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if (dstParentMma.getVersion() == 1 ||
dstParentMma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
SetVector<Operation *> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
if (llvm::find_if(bwdSlices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) == bwdSlices.end())
return mlir::failure();
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -81,8 +96,11 @@ public:
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention
// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
// return mlir::failure();
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -160,6 +178,10 @@ public:
!isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(convert.getOperand()) &&
isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
@@ -586,13 +608,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
template <int version>
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
int numWarps);
template <>
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1, shape, numWarps);
@@ -611,17 +629,25 @@ SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
return ret;
}
template <>
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(2, shape, numWarps);
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
@@ -638,6 +664,55 @@ SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
}
} // namespace
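Before the new rewrite patterns, a hedged Python sketch of the warpsPerTileV2 heuristic above. The hunk is cut off mid-condition, so the tail of the doubling loop is a plausible reconstruction, not a verbatim copy; feeds_another_dot stands in for the forward-slice check on chained dots:

def warps_per_tile_v2(shape, num_warps, feeds_another_dot=False):
    if feeds_another_dot:              # chained dots keep all warps on dim 0
        return [num_warps, 1]
    shape_per_warp = [16, 8]
    ret = [1, 1]
    while ret[0] * ret[1] < num_warps:
        # Double along whichever dimension still has the most warp tiles left.
        if shape[0] // shape_per_warp[0] // ret[0] >= shape[1] // shape_per_warp[1] // ret[1]:
            if ret[0] < shape[0] // shape_per_warp[0]:
                ret[0] *= 2
            else:
                ret[1] *= 2
        else:
            ret[1] *= 2
    return ret

# e.g. shape = (128, 64), num_warps = 4 -> [2, 2] under this sketch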
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstSharedLayout =
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!srcBlockedLayout || !dstSharedLayout)
return failure();
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
return failure();
// For now only works if single use is transpose
// TODO: rematerialize #shared uses
auto users = op->getUsers();
if (std::distance(users.begin(), users.end()) != 1 ||
!isa<triton::TransOp>(*users.begin()))
return failure();
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
op->getContext(), dstSharedLayout.getVec(),
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
srcBlockedLayout.getOrder());
auto tmpType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), tmpShared);
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), tmpType, cvt.getOperand());
auto newDstType = RankedTensorType::get(
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
srcType.getElementType(), dstSharedLayout);
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
tmpCvt.getResult());
rewriter.replaceOp(*users.begin(), newTrans.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -646,13 +721,14 @@ public:
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTile<1>(shape, numWarps);
return warpsPerTileV1(dotOp, shape, numWarps);
case 2:
return warpsPerTile<2>(shape, numWarps);
return warpsPerTileV2(dotOp, shape, numWarps);
default:
assert(false && "not supported version");
return {0, 0};
@@ -680,11 +756,12 @@ public:
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int version = computeCapabilityToMMAVersion(computeCapability);
auto newRetType = RankedTensorType::get(
retShape, oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), version,
getWarpsPerTile(retShape, version, numWarps)));
getWarpsPerTile(dotOp, retShape, version, numWarps)));
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -704,8 +781,7 @@ public:
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
dotOp.transA(), dotOp.transB());
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
@@ -731,8 +807,9 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<SimplifyConversion>(context);
// patterns.add<DecomposeDotOperand>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);

View File

@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
tensorType.getEncoding());
}
#define int_attr(num) builder.getI64IntegerAttr(num)
namespace {
class LoopPipeliner {
/// cache forOp we are working on
/// Cache forOp we are working on
scf::ForOp forOp;
/// cache YieldOp for this forOp
/// Cache YieldOp for this forOp
scf::YieldOp yieldOp;
/// loads to be pipelined
/// Loads to be pipelined
SetVector<Value> loads;
/// the value that each load will be mapped to (after layout conversion)
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
@@ -51,7 +53,7 @@ class LoopPipeliner {
///
Value loopIterIdx;
/// comments on numStages:
/// Comments on numStages:
/// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body
int numStages;
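The numStages comment describes the usual software-pipelining split; a schematic Python sketch of that scheme (illustrative only, not the pass's exact code): the prologue issues loads for the first numStages-1 iterations, then the steady-state body consumes iteration i while prefetching iteration i + numStages - 1.

def pipeline_schedule(num_iters, num_stages):
    prologue = [("load", i) for i in range(min(num_stages - 1, num_iters))]
    body = []
    for i in range(num_iters):
        nxt = i + num_stages - 1
        if nxt < num_iters:
            body.append(("load", nxt))   # prefetch a future iteration
        body.append(("compute", i))      # consume the current one
    return prologue, body

# e.g. 4 iterations, 3 stages:
# prologue -> [('load', 0), ('load', 1)]
# body     -> [('load', 2), ('compute', 0), ('load', 3), ('compute', 1),
#              ('compute', 2), ('compute', 3)]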
@@ -61,6 +63,7 @@ class LoopPipeliner {
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation *> depOps;
@@ -71,7 +74,7 @@ class LoopPipeliner {
Value lookupOrDefault(Value origin, int stage);
/// returns a empty buffer of size <numStages, ...>
/// Returns an empty buffer of size <numStages, ...>
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
public:
@@ -84,7 +87,7 @@ public:
/// Collect loads to pipeline. Return success if we can pipeline this loop
LogicalResult initialize();
/// emit pipelined loads (before loop body)
/// Emit pipelined loads (before loop body)
void emitPrologue();
/// emit pipelined loads (after loop body)
@@ -120,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
return;
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depend on value in previous iterations
@@ -134,7 +141,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// allocate a buffer for each pipelined tensor
// Allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
@@ -215,9 +222,9 @@ LogicalResult LoopPipeliner::initialize() {
loads.insert(loadOp);
}
// we have some loads to pipeline
// We have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
// Update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
@@ -244,23 +251,20 @@ void LoopPipeliner::emitPrologue() {
setValueMapping(arg, operand.get(), 0);
}
// helper to construct int attribute
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
// Special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// special handling for loop condition as there is no condition in ForOp
// Special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
// Rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
@@ -314,7 +318,7 @@ void LoopPipeliner::emitPrologue() {
}
}
// update mapping of results
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
@@ -350,13 +354,14 @@ void LoopPipeliner::emitPrologue() {
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
loadsExtract[loadOp] = extractSlice;
}
// bump up loopIterIdx, this is used for getting the correct slice for the
// Bump up loopIterIdx, this is used for getting the correct slice for the
// *next* iteration
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
@@ -365,9 +370,6 @@ void LoopPipeliner::emitPrologue() {
void LoopPipeliner::emitEpilogue() {
// If there are any outstanding async copies, we need to wait for them.
// TODO(Keren): We may want to completely avoid the async copies in the last
// few iterations by setting is_masked attribute to true. We don't want to use
// the mask operand because it's a tensor but not a scalar.
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
@@ -376,14 +378,13 @@ void LoopPipeliner::emitEpilogue() {
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// order of new args:
// (original args),
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
// Order of new args:
// (original args)
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages - 1)
// (iv at stage numStages - 2)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
@@ -424,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
@@ -465,15 +467,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// special handling for iv & loop condition
// Special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// slice index
// Slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
@@ -490,7 +493,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
// Update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
@@ -500,7 +503,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mask.getLoc(), mask.getType(), nextLoopCond);
newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// If mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
@@ -522,18 +525,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1),
intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);
// update mapping of results
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// if this is a loop-carried value, update the mapping for yield
// If this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
@@ -583,7 +587,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
it->getDefiningOp()->moveAfter(asyncWait);
}
// bump iteration count
// Bump iteration count
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
@@ -600,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
}
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);

View File

@@ -131,6 +131,11 @@ LogicalResult Prefetcher::initialize() {
if (dotsInFor.empty())
return failure();
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
if (dotsInFor.size() > 1)
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())

View File

@@ -25,8 +25,8 @@ def get_build_type():
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
else:
return "Debug"
# TODO(Keren): Restore this before we merge into master
return "RelWithDebInfo"
# TODO: change to release when stable enough
#return "Release"

View File

@@ -11,6 +11,7 @@
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -115,6 +116,10 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
// we load LLVM because the frontend uses LLVM.undef for
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
});
// .def(py::init([](){
// mlir::MLIRContext context;
@@ -187,6 +192,7 @@ void init_triton_ir(py::module &&m) {
/* issue a warning */
}
})
.def("get_context", &mlir::Value::getContext)
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
@@ -335,10 +341,21 @@ void init_triton_ir(py::module &&m) {
return funcs[0];
});
m.def("make_attr",
[](const std::vector<int> &values, mlir::MLIRContext &context) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(
{static_cast<int64_t>(values.size())},
mlir::IntegerType::get(&context, 32)),
values)
.cast<mlir::Attribute>();
});
m.def(
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// initialize registry
// note: we initialize llvm for undef
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
@@ -1068,6 +1085,16 @@ void init_triton_ir(py::module &&m) {
mlir::RankedTensorType::get(shape, lhsType.getElementType()),
lhs, rhs);
})
.def("create_trans",
[](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
auto argEltType = argType.getElementType();
std::vector<int64_t> retShape = argType.getShape();
std::reverse(retShape.begin(), retShape.end());
return self.create<mlir::triton::TransOp>(
loc, mlir::RankedTensorType::get(retShape, argEltType), arg);
})
.def("create_broadcast",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
@@ -1096,7 +1123,8 @@ void init_triton_ir(py::module &&m) {
mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType;
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
@@ -1156,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32, bool transA,
bool transB) -> mlir::Value {
mlir::Value &c, bool allowTF32) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32, transA, transB);
allowTF32);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -1195,10 +1222,11 @@ void init_triton_ir(py::module &&m) {
operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
mlir::Type resType = inputTensorType.getElementType();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
mlir::Type resType = withIndex ? self.getI32Type()
: inputTensorType.getElementType();
if (!shape.empty()) {
resType = mlir::RankedTensorType::get(
shape, inputTensorType.getElementType());
resType = mlir::RankedTensorType::get(shape, resType);
}
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);
@@ -1231,6 +1259,12 @@ void init_triton_ir(py::module &&m) {
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
});
py::class_<mlir::PassManager>(m, "pass_manager")
@@ -1348,6 +1382,11 @@ void init_triton_translation(py::module &m) {
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module)
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
", lineno: " + std::to_string(error.getLineNo()));
// translate module to PTX
auto ptxCode =
triton::translateLLVMIRToPTX(*module, capability, version);

View File

@@ -0,0 +1,18 @@
import os
from typing import Optional
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
if os.path.exists(_p):
SYSTEM_LIBDEVICE_PATH = _p
def system_libdevice_path() -> str:
assert SYSTEM_LIBDEVICE_PATH is not None, \
"Could not find libdevice.10.bc path"
return SYSTEM_LIBDEVICE_PATH
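A minimal usage sketch of the new helper, mirroring how the tests below consume it (kernel, x, y and BLOCK are placeholders, not part of this file):

from tests.libdevice_testutil import system_libdevice_path

libs = {"libdevice": system_libdevice_path()}
# kernel[(1,)](x, y, BLOCK=128, extern_libs=libs)   # launch site elided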

View File

@@ -0,0 +1,188 @@
import pytest
import torch
import triton
# TODO: float32 fails
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=384):
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
c_shape = (Z, H, M, N)
shape = {
"sdd": (M, N),
"dsd": (a_shape[2], a_shape[3]),
"dds": (b_shape[2], b_shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.requires_grad_().retain_grad()
b_ref.requires_grad_().retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
# triton result
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.requires_grad_().retain_grad()
b_tri.requires_grad_().retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
configs = [
(16, 256),
(32, 576),
(64, 1871),
(128, 2511),
]
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
# initialize layout
# make sure each row has at least one non-zero element
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
if is_dense:
layout[:] = 1
else:
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_attention_fwd_bwd(
block,
dtype,
input_scale=1.0,
scale=1 / 8.0,
n_ctx=256,
batch_size=2,
n_heads=2,
):
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
# Triton:
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
attn_mask = torch.tril(attn_mask, diagonal=0)
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])
def triton_attention(
layout,
block: int,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
a = sparse_dot_dsd_nn(w, value)
return a
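For orientation, the block-sparse composition in triton_attention (sdd QK^T, masked softmax, dsd PV) is the familiar attention formula; a short dense torch sketch mirroring the reference already used in test_attention_fwd_bwd:

import torch

def dense_attention_reference(query, key, value, scale, additive_mask):
    # Same three steps as triton_attention, but dense:
    # scores = scale * QK^T, masked softmax over the last dim, then multiply by V.
    scores = scale * torch.einsum("bhsd,bhtd->bhst", query, key)
    scores = scores + additive_mask   # large negative where masked out
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bhtd->bhsd", probs, value)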

View File

@@ -12,6 +12,7 @@ import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
from tests.libdevice_testutil import system_libdevice_path
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -667,7 +668,6 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
tl.atomic_add(Z + off1, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
print(x)
# reference result
z_ref = np.sum(x, axis=axis, keepdims=False)
# triton result
@@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
# def test_atomic_cas():
# # 1. make sure that atomic_cas changes the original value (Lock)
# @triton.jit
# def change_value(Lock):
# tl.atomic_cas(Lock, 0, 1)
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
# change_value[(1,)](Lock)
# assert (Lock[0] == 1)
# # 2. only one block enters the critical section
# @triton.jit
# def serialized_add(data, Lock):
# ptrs = data + tl.arange(0, 128)
# while tl.atomic_cas(Lock, 0, 1) == 1:
# pass
# tl.store(ptrs, tl.load(ptrs) + 1.0)
# # release lock
# tl.atomic_xchg(Lock, 0)
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
# data = torch.zeros((128,), device='cuda', dtype=torch.float32)
# ref = torch.full((128,), 64.0)
# serialized_add[(64,)](data, Lock)
# triton.testing.assert_almost_equal(data, ref)
def test_simple_atomic_cas():
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
@@ -717,6 +688,25 @@ def test_simple_atomic_cas():
assert (Lock[0] == 1)
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
# release lock
tl.atomic_xchg(Lock, 0)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
triton.testing.assert_almost_equal(data, ref)
# # ---------------
# # test cast
# # ---------------
@@ -1077,122 +1067,126 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# # ---------------
# @pytest.mark.parametrize("epilogue, allow_tf32, dtype",
# [(epilogue, allow_tf32, dtype)
# for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
# for allow_tf32 in [True, False]
# for dtype in ['float16']
# if not (allow_tf32 and (dtype in ['float16']))])
# def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
# cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
# if cc < 80:
# if dtype == 'int8':
# pytest.skip("Only test int8 on devices with sm >= 80")
# elif dtype == 'float32' and allow_tf32:
# pytest.skip("Only test tf32 on devices with sm >= 80")
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 80:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
# M, N, K = 128, 128, 64
# num_warps = 8
# trans_a, trans_b = False, False
M, N, K = 64, 64, 64
num_warps = 4
trans_a, trans_b = False, False
# # triton kernel
# @triton.jit
# def kernel(X, stride_xm, stride_xk,
# Y, stride_yk, stride_yn,
# W, stride_wn, stride_wl,
# Z, stride_zm, stride_zn,
# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
# ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
# ALLOW_TF32: tl.constexpr,
# DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
# TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
# off_m = tl.arange(0, BLOCK_M)
# off_n = tl.arange(0, BLOCK_N)
# off_l = tl.arange(0, BLOCK_N)
# off_k = tl.arange(0, BLOCK_K)
# Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
# Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
# Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
# Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
# z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
# if ADD_MATRIX:
# z += tl.load(Zs)
# if ADD_ROWS:
# ZRs = Z + off_m * stride_zm
# z += tl.load(ZRs)[:, None]
# if ADD_COLS:
# ZCs = Z + off_n * stride_zn
# z += tl.load(ZCs)[None, :]
# if DO_SOFTMAX:
# max = tl.max(z, 1)
# z = z - max[:, None]
# num = tl.exp(z)
# den = tl.sum(num, 1)
# z = num / den[:, None]
# if CHAIN_DOT:
# # tl.store(Zs, z)
# # tl.debug_barrier()
# z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
# tl.store(Zs, z)
# # input
# rs = RandomState(17)
# x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
# y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
# w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
# if allow_tf32:
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
# y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
# w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
# x_tri = to_triton(x, device=device)
# y_tri = to_triton(y, device=device)
# w_tri = to_triton(w, device=device)
# # triton result
# z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
# z_tri = to_triton(z, device=device)
# if epilogue == 'trans':
# z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
# y_tri, y_tri.stride(0), y_tri.stride(1),
# w_tri, w_tri.stride(0), w_tri.stride(1),
# z_tri, z_tri.stride(0), z_tri.stride(1),
# TRANS_A=trans_a, TRANS_B=trans_b,
# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
# ADD_MATRIX=epilogue == 'add-matrix',
# ADD_ROWS=epilogue == 'add-rows',
# ADD_COLS=epilogue == 'add-cols',
# DO_SOFTMAX=epilogue == 'softmax',
# CHAIN_DOT=epilogue == 'chain-dot',
# ALLOW_TF32=allow_tf32,
# num_warps=num_warps)
# # torch result
# x_ref = x.T if trans_a else x
# y_ref = y.T if trans_b else y
# z_ref = np.matmul(x_ref, y_ref)
# if epilogue == 'add-matrix':
# z_ref += z
# if epilogue == 'add-rows':
# z_ref += z[:, 0][:, None]
# if epilogue == 'add-cols':
# z_ref += z[0, :][None, :]
# if epilogue == 'softmax':
# num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
# denom = np.sum(num, axis=-1, keepdims=True)
# z_ref = num / denom
# if epilogue == 'chain-dot':
# z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
# # compare
# # print(z_ref[:,0], z_tri[:,0])
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# # make sure ld/st are vectorized
# ptx = pgm.asm['ptx']
# assert 'ld.global.v4' in ptx
# assert 'st.global.v4' in ptx
# if allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
# elif dtype == 'float32':
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
# elif dtype == 'int8':
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
x = tl.trans(x) if TRANS_A else x
y = tl.trans(y) if TRANS_B else y
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if ADD_COLS:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
if DO_SOFTMAX:
max = tl.max(z, 1)
z = z - max[:, None]
num = tl.exp(z)
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(z.to(tl.float16), tl.load(Ws))
tl.store(Zs, z)
# input
rs = RandomState(17)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0, :][None, :]
if epilogue == 'softmax':
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# def test_dot_without_load():
@@ -1559,7 +1553,7 @@ def test_num_warps_pow2():
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float32', 'libdevice.pow', system_libdevice_path()),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):

View File

@@ -5,6 +5,7 @@ import _testcapi
import pytest
import torch
from torch.testing import assert_close
from tests.libdevice_testutil import system_libdevice_path
import triton
import triton.language as tl
@@ -32,8 +33,6 @@ torch_ops = {
"where": "where",
}
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
def get_tensor(shape, data_type, b_positive=False):
x = None
@@ -90,7 +89,11 @@ def kernel(X, Y, BLOCK: tl.constexpr):
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](
x, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x)
# compare
@@ -134,7 +137,11 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](
x0, x1, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
if expr == "cdiv":
@@ -182,7 +189,11 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
kernel[(1,)](
x0, x1, x2, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)

View File

@@ -5,6 +5,7 @@ from torch.testing import assert_close
import triton
import triton.language as tl
from tests.libdevice_testutil import system_libdevice_path
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
@@ -125,7 +126,7 @@ def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
[('int32', 'libdevice.ffs', system_libdevice_path()),
('int32', 'libdevice.ffs', '')])
def test_libdevice(dtype_str, expr, lib_path):
src = f"""

View File

@@ -172,8 +172,9 @@ def get_proper_err(a, b, golden):
[128, 64, 128, 4, 128, 64, 128, False, False],
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
# K-Forloop
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
[64, 32, 128, 4, 64, 32, 64, False, False],
[128, 16, 128, 4, 128, 16, 32, False, False],
[32, 16, 128, 4, 32, 16, 32, False, False],
@@ -187,7 +188,8 @@ def get_proper_err(a, b, golden):
[128, 256, 128, 4, 128, 256, 32, False, False],
[256, 128, 64, 4, 256, 128, 16, False, False],
[128, 64, 128, 4, 128, 64, 32, False, False],
# [16, 16, 64, 4, 16, 16, 16, False, False], # TODO failed due to pipeline pass
[16, 16, 64, 4, 16, 16, 16, False, False],
[32, 32, 64, 4, 32, 32, 32, False, False],
# trans
[128, 64, 128, 4, 128, 64, 32, True, False],
[128, 64, 128, 4, 128, 64, 32, False, True],
@@ -218,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
[32, 32, 16, 4, 32, 32, 16],
[32, 16, 16, 4, 32, 32, 16],
[128, 8, 8, 4, 32, 32, 16],
# TODO[Superjomn]: fix it later
# [127, 41, 43, 4, 32, 32, 16],
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
[32, 32, 16, 4, 32, 32, 16, False],
[32, 32, 16, 4, 32, 32, 16, True],
[32, 16, 16, 4, 32, 32, 16, False],
[32, 16, 16, 4, 32, 32, 16, True],
[127, 41, 43, 4, 32, 32, 16, False],
[127, 41, 43, 4, 32, 32, 16, True],
[128, 8, 8, 4, 32, 32, 16, False],
[128, 8, 8, 4, 32, 32, 16, True]
])
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
@@ -234,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
ALLOW_TF32: tl.constexpr
):
pid = tl.program_id(axis=0)
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -251,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, a_mask)
b = tl.load(b_ptrs, b_mask)
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
accumulator += tl.dot(a, b, allow_tf32=False)
a = tl.load(a_ptrs, a_mask, other=0.0)
b = tl.load(b_ptrs, b_mask, other=0.0)
accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_k += BLOCK_SIZE_K
@@ -265,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, c_mask)
# Configure the pytorch counterpart
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
@@ -275,8 +283,30 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
golden = torch.matmul(a, b)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
if allow_tf32:
# TF32 is not accurate enough
torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
else:
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[16, 16, 16],
[16, 16, 32],
[32, 16, 16],
[32, 32, 32],
[128, 16, 16],
]
for num_warps in [1]
for trans_a in [False]
for trans_b in [False]
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)

View File

@@ -1,4 +1,5 @@
import pytest
import numpy as np
import torch
from torch.testing import assert_close
@@ -13,7 +14,9 @@ dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
def get_reduced_dtype(dtype):
def get_reduced_dtype(op, dtype):
if op in ['argmin', 'argmax']:
return torch.int32
if dtype in [torch.int8, torch.int16, torch.uint8]:
return torch.int32
if dtype in [torch.bfloat16]:
@@ -48,7 +51,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
reduce1d_configs = [
(op, dtype, shape)
for op in ['sum', 'min', 'max']
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
for dtype in dtypes
for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
]
@@ -56,8 +59,11 @@ reduce1d_configs = [
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
if op == 'xor_sum' and dtype in float_dtypes:
return
dtype = dtype_mapping[dtype]
reduced_dtype = get_reduced_dtype(dtype)
reduced_dtype = get_reduced_dtype(op, dtype)
if dtype.is_floating_point:
x = torch.randn((shape,), device='cuda', dtype=dtype)
@@ -79,8 +85,17 @@ def test_reduce1d(op, dtype, shape):
golden_z = torch.sum(x, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x).to(reduced_dtype)
else:
elif op == 'max':
golden_z = torch.max(x).to(reduced_dtype)
elif op == 'argmin':
golden_z = torch.argmin(x).to(reduced_dtype)
elif op == 'argmax':
golden_z = torch.argmax(x).to(reduced_dtype)
elif op == 'xor_sum':
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
else:
raise RuntimeError(f'Unknown reduce op {op}')
if dtype.is_floating_point and op == 'sum':
if shape >= 256:
@@ -95,7 +110,7 @@ def test_reduce1d(op, dtype, shape):
reduce2d_configs = [
(op, dtype, shape, axis)
for op in ['sum', 'min', 'max']
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
for dtype in dtypes
for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
for axis in [0, 1]
@@ -104,8 +119,11 @@ reduce2d_configs = [
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
if op == 'xor_sum' and dtype in float_dtypes:
return
dtype = dtype_mapping[dtype]
reduced_dtype = get_reduced_dtype(dtype)
reduced_dtype = get_reduced_dtype(op, dtype)
reduced_shape = (shape[1 - axis],)
if dtype.is_floating_point:
@@ -123,8 +141,18 @@ def test_reduce2d(op, dtype, shape, axis):
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
else:
elif op == 'max':
golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
elif op == 'argmin':
golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
elif op == 'argmax':
golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
elif op == 'xor_sum':
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
else:
raise RuntimeError(f'Unknown reduce op {op}')
if dtype.is_floating_point and op == 'sum':
if shape[axis] >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1)

View File

@@ -1,15 +1,52 @@
"""isort:skip_file"""
# flake8: noqa: F401
__version__ = '2.0.0'
# ---------------------------------------
# Note: import order is significant here.
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
import torch # noqa: F401
# submodules
from .utils import *
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from . import impl
from .utils import (
cdiv,
MockTensor,
next_power_of_2,
reinterpret,
TensorWrapper,
)
from .runtime import (
autotune,
Config,
heuristics,
JITFunction,
KernelInterface,
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import testing
from . import ops
__all__ = [
"autotune",
"cdiv",
"CompilationError",
"compile",
"Config",
"heuristics",
"impl",
"jit",
"JITFunction",
"KernelInterface",
"language",
"MockTensor",
"next_power_of_2",
"ops",
"reinterpret",
"runtime",
"TensorWrapper",
"testing",
]

View File

@@ -25,6 +25,8 @@ from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from . import impl
from .tools.disasm import extract
@@ -359,7 +361,7 @@ class CodeGenerator(ast.NodeVisitor):
cond = cond.to(triton.language.int1, _builder=self.builder)
with enter_sub_region(self) as sr:
liveins, ip_block = sr
liveins_copy = liveins.copy()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
@@ -394,7 +396,15 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
# Names defined in the else block but not in the then block
# are looked up in the parent scope and yielded
for else_name in else_defs:
if else_name in liveins and else_name not in then_defs:
if else_defs[else_name].type == liveins[else_name].type:
names.append(else_name)
ret_types.append(else_defs[else_name].type)
then_defs[else_name] = liveins_copy[else_name]
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse: # with else block
@@ -528,8 +538,7 @@ class CodeGenerator(ast.NodeVisitor):
[ty.to_ir(self.builder) for ty in ret_types])
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
if len(yields) > 0:
self.builder.create_yield_op([y.handle for y in yields])
self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op
for i, name in enumerate(names):
@@ -594,11 +603,8 @@ class CodeGenerator(ast.NodeVisitor):
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
# We can use any value because the variable isn't a constexpr
# but use a distinctive value (of the right type) to ease debugging
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
self.visit(init_node)
iv = self.builder.create_undef(self.builder.get_int32_ty())
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -711,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core or \
isinstance(fn, triton.language.extern.ExternalFunction):
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
or impl.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
@@ -757,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Attribute(self, node):
lhs = self.visit(node.value)
if isinstance(lhs, triton.language.tensor):
if node.attr == "T":
return triton.language.semantic.trans(lhs, builder=self.builder)
return getattr(lhs, node.attr)
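A minimal usage sketch (assumed kernel, not taken from the diff): with the visit_Attribute change above, `x.T` on a tile inside a jitted function lowers to a transpose of that tile.

import triton
import triton.language as tl

@triton.jit
def transpose_tile(X, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load a BLOCK x BLOCK tile (row-major with stride BLOCK) and store its transpose.
    x = tl.load(X + offs[:, None] * BLOCK + offs[None, :])
    tl.store(Z + offs[:, None] * BLOCK + offs[None, :], x.T)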
def visit_Expr(self, node):
@@ -1014,6 +1022,7 @@ def ty_to_cpp(ty):
"u32": "uint32_t",
"u64": "uint64_t",
"fp32": "float",
"f32": "float",
}[ty]
@@ -1044,6 +1053,7 @@ def generate_launcher(constants, signature):
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]
@@ -1343,7 +1353,31 @@ def make_hash(fn, **kwargs):
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
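To see what the prototype and arg-type regexes extract, here is a small standalone check; the sample func line is illustrative, not taken from a real kernel:

import re

mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'

src = 'func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32) {'
match = re.search(mlir_prototype_pattern, src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(mlir_arg_type_pattern, signature)
print(name)   # @add_kernel
print(types)  # ['!tt.ptr<f32>', 'i32']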
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
@@ -1354,6 +1388,27 @@ def compile(fn, **kwargs):
context = _triton.ir.context()
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
capability = torch.cuda.get_device_capability()
capability = capability[0]*10 + capability[1]
# build compilation stages
stages = {
"ast" : (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
@@ -1368,13 +1423,17 @@ def compile(fn, **kwargs):
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
name, ir = os.path.basename(fn).split(".")
assert ir == "ttgir"
asm[ir] = _triton.ir.parse_mlir_module(fn, context)
function = asm[ir].get_single_function()
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
_, ir = os.path.basename(fn).split(".")
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir], signature)
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = 2
first_stage = list(stages.keys()).index(ir)
# cache manager
so_path = make_stub(name, signature, constants)
@@ -1384,58 +1443,42 @@ def compile(fn, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# initialize compilation params
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
name, ext = os.path.basename(fn).split(".")
# load metadata if any
metadata = None
if fn_cache_manager.has_file(f'{name}.json'):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
# build compilation stages
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, compute_capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, compute_capability))
}
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and \
ir in metadata["ctime"] and \
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
# return handle to compiled kernel
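As a standalone illustration of how the entry stage is chosen, the helper below is mine (not part of the compiler) and only mirrors `first_stage = list(stages.keys()).index(ir)` above.

```python
# Stage order as defined in the `stages` dict above.
STAGE_ORDER = ["ast", "ttir", "ttgir", "llir", "ptx", "cubin"]

def first_stage_for(fn) -> int:
    """Illustrative only: pick the stage to resume from, by file extension."""
    if isinstance(fn, str):
        ext = fn.rsplit(".", 1)[-1]   # e.g. "kernel.ttgir" -> "ttgir"
    else:
        ext = "ast"                   # JIT functions start from the AST
    return STAGE_ORDER.index(ext)

print(first_stage_for("kernel.ttgir"))  # 2 -> ast/ttir are skipped
print(first_stage_for("kernel.ptx"))    # 4 -> only ptx -> cubin remain
```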
@@ -1515,7 +1558,7 @@ class CudaUtils(object):
}
}
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name;
@@ -1530,7 +1573,6 @@ class CudaUtils(object):
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
Py_BEGIN_ALLOW_THREADS;
// create driver handles
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
@@ -1548,7 +1590,6 @@ class CudaUtils(object):
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
}
Py_END_ALLOW_THREADS;
if(PyErr_Occurred()) {
return NULL;

View File

@@ -0,0 +1,22 @@
"""Triton internal implementation details.
Client libraries should not import interfaces from the `triton.impl` module, as the details are subject to change.
APIs defined in the `triton.impl` module which are public will be re-exported
in other relevant `triton` module namespaces.
"""
from triton._C.libtriton.triton import ir
from .base import (
builtin,
extern,
is_builtin,
)
__all__ = [
"builtin",
"extern",
"ir",
"is_builtin",
]
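A small sketch of what this policy means in practice (assuming an installed Triton with the layout above): the public `triton` / `triton.language` namespaces re-export what clients need, while `triton.impl` only hosts the order-dependent internals.

```python
import triton.language as tl
from triton.impl import is_builtin

print(is_builtin(tl.trans))  # True: tl.trans is decorated with @builtin
print(is_builtin(len))       # False: ordinary callables are not Triton builtins
```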

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from functools import wraps
from typing import TypeVar
T = TypeVar("T")
TRITON_BUILTIN = "__triton_builtin__"
def builtin(fn: T) -> T:
"""Mark a function as a builtin."""
assert callable(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
return wrapper
def is_builtin(fn) -> bool:
"""Is this a registered triton builtin function?"""
return getattr(fn, TRITON_BUILTIN, False)
def extern(fn: T) -> T:
"""A decorator for external functions."""
return builtin(fn)
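A minimal usage sketch of the decorator and its sense bit; `my_op` and the dummy builder object are illustrative, not part of the library.

```python
from triton.impl import builtin, is_builtin

@builtin
def my_op(x, _builder=None):
    return x + 1

assert is_builtin(my_op)            # detected via the __triton_builtin__ attribute
assert not is_builtin(lambda x: x)  # undecorated callables are not builtins

print(my_op(41, _builder=object())) # 42: a builder is supplied explicitly
try:
    my_op(41)                       # outside @triton.jit, no builder is passed
except ValueError as e:
    print(e)
```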

View File

@@ -1,4 +1,173 @@
# flake8: noqa: F401
"""isort:skip_file"""
# Import order is significant here.
from ..impl import (
ir,
builtin,
)
from . import core, extern, libdevice, random
from .core import *
from .random import *
from .core import (
abs,
arange,
argmin,
argmax,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
builtin,
cat,
cdiv,
constexpr,
cos,
debug_barrier,
dot,
dtype,
exp,
fdiv,
float16,
float32,
float64,
float8,
function_type,
int1,
int16,
int32,
int64,
int8,
load,
log,
max,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
pointer_type,
printf,
program_id,
ravel,
sigmoid,
sin,
softmax,
sqrt,
store,
sum,
swizzle2d,
tensor,
trans,
triton,
uint16,
uint32,
uint64,
uint8,
umulhi,
void,
where,
xor_sum,
zeros,
zeros_like,
)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint32_to_uniform_float,
)
__all__ = [
"abs",
"arange",
"argmin",
"argmax",
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
"bfloat16",
"block_type",
"builtin",
"cat",
"cdiv",
"constexpr",
"cos",
"debug_barrier",
"dot",
"dtype",
"exp",
"fdiv",
"float16",
"float32",
"float64",
"float8",
"function_type",
"int1",
"int16",
"int32",
"int64",
"int8",
"ir",
"load",
"log",
"max",
"max_contiguous",
"maximum",
"min",
"minimum",
"multiple_of",
"num_programs",
"pair_uniform_to_normal",
"philox",
"philox_impl",
"pi32_t",
"pointer_type",
"printf",
"program_id",
"rand",
"rand4x",
"randint",
"randint4x",
"randn",
"randn4x",
"ravel",
"sigmoid",
"sin",
"softmax",
"sqrt",
"store",
"sum",
"swizzle2d",
"tensor",
"trans",
"triton",
"uint16",
"uint32",
"uint32_to_uniform_float",
"uint64",
"uint8",
"umulhi",
"void",
"where",
"xor_sum",
"zeros",
"zeros_like",
]

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
from enum import Enum
from functools import wraps
from typing import List
from typing import List, Callable, TypeVar
import triton
from . import semantic
from . import builtin, semantic
from triton._C.libtriton.triton import ir
T = TypeVar('T')
def _to_tensor(x, builder):
if isinstance(x, bool):
@@ -33,17 +33,6 @@ def _to_tensor(x, builder):
assert False, f'cannot convert {x} to tensor'
def builtin(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
return wrapper
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -405,14 +394,14 @@ class constexpr:
return constexpr(self.value != other.value)
def __bool__(self):
return constexpr(bool(self.value))
return bool(self.value)
def __neg__(self):
return constexpr(-self.value)
def __pos__(self):
return constexpr(+self.value)
def __invert__(self):
return constexpr(~self.value)
@@ -614,9 +603,9 @@ class tensor:
assert False, "unsupported"
return ret
# x[:, None, :, None]
# x = expand_dims(x, axis=1)
# x = expand_dims(x, axis=2)
@property
def T(self):
assert False, "Transposition must be created by the AST Visitor"
@builtin
def to(self, dtype, bitcast=False, _builder=None):
@@ -737,6 +726,9 @@ def broadcast_to(input, shape, _builder=None):
"""
return semantic.broadcast_impl_shape(input, shape, _builder)
@builtin
def trans(input, _builder=None):
return semantic.trans(input, _builder)
@builtin
def cat(input, other, _builder=None):
@@ -766,6 +758,10 @@ def view(input, shape, _builder=None):
shape = [x.value for x in shape]
return semantic.view(input, shape, _builder)
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view
return view(input, shape, _builder)
# -----------------------
# Linear Algebra
@@ -773,7 +769,7 @@ def view(input, shape, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
def dot(input, other, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -785,7 +781,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
return semantic.dot(input, other, allow_tf32, _builder)
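Since the transposition flags are gone, callers transpose operands explicitly with `tl.trans` (2D-only, per the semantic check further below). A before/after sketch inside a hypothetical kernel, where `q_ptr`/`k_ptr` are assumed to address contiguous fp16 blocks, `out_ptr` an fp32 block, and the block sizes are powers of two satisfying the usual `tl.dot` minimums:

```python
import triton
import triton.language as tl

@triton.jit
def qk_kernel(q_ptr, k_ptr, out_ptr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    q = tl.load(q_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])  # [M, K]
    k = tl.load(k_ptr + offs_n[:, None] * BLOCK_K + offs_k[None, :])  # [N, K]
    # old API: qk = tl.dot(q, k, trans_b=True)
    qk = tl.dot(q, tl.trans(k))                                       # [M, N]
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], qk)
```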
# -----------------------
@@ -847,9 +843,9 @@ def store(pointer, value, mask=None, _builder=None):
# Atomic Memory Operations
# -----------------------
def _add_atomic_docstr(name):
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`.
@@ -970,9 +966,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
return semantic.fdiv(x, y, ieee_rounding, _builder)
def _add_math_1arg_docstr(name):
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Computes the element-wise {name} of :code:`x`
@@ -1019,9 +1015,9 @@ def sqrt(x, _builder=None):
# Reductions
# -----------------------
def _add_reduction_docstr(name):
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
@@ -1041,6 +1037,13 @@ def max(input, axis, _builder=None):
return semantic.max(input, axis, _builder)
@builtin
@_add_reduction_docstr("maximum index")
def argmax(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmax(input, axis, _builder)
@builtin
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
@@ -1048,6 +1051,13 @@ def min(input, axis, _builder=None):
return semantic.min(input, axis, _builder)
@builtin
@_add_reduction_docstr("minimum index")
def argmin(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmin(input, axis, _builder)
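A brief usage sketch for the new index reductions; the kernel, strides, and buffer types are illustrative, and the block sizes are assumed to be powers of two.

```python
import triton
import triton.language as tl

@triton.jit
def rowwise_argmax(x_ptr, out_ptr, stride_m,
                   M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    x = tl.load(x_ptr + offs_m[:, None] * stride_m + offs_n[None, :])
    idx = tl.argmax(x, axis=1)       # index of the per-row maximum, shape [M]
    tl.store(out_ptr + offs_m, idx)  # out_ptr is assumed to be an int32 buffer
```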
@builtin
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):

View File

@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_external_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)

File diff suppressed because it is too large.

View File

@@ -502,6 +502,11 @@ def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
# TODO: check types
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2:
raise ValueError("Only 2D tensors can be transposed")
ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
return tl.tensor(builder.create_trans(input.handle), ret_type)
def broadcast_impl_shape(input: tl.tensor,
shape: List[int],
@@ -971,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
trans_a: bool,
trans_b: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
if lhs.type.scalar.is_int():
@@ -981,11 +984,11 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = lhs.type.shape[1 if trans_a else 0]
N = rhs.type.shape[0 if trans_b else 1]
M = lhs.type.shape[0]
N = rhs.type.shape[1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
@@ -1061,10 +1064,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
@@ -1109,16 +1120,16 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
##
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.multiple_of(values)
return x
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values")
x.handle.max_contiguous(values)
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
return x

View File

@@ -1,5 +1,12 @@
# flake8: noqa: F401
#from .conv import _conv, conv
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
]

View File

@@ -1,3 +1,7 @@
# flake8: noqa: F401
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -1,2 +1,12 @@
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
from .autotuner import Config, Heuristics, autotune, heuristics
from .jit import JITFunction, KernelInterface, version_key
__all__ = [
"Config",
"Heuristics",
"autotune",
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
]

View File

@@ -8,6 +8,7 @@ import os
import subprocess
import textwrap
from collections import namedtuple
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
import torch
@@ -19,6 +20,9 @@ try:
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
T = TypeVar('T')
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
@@ -94,20 +98,20 @@ def version_key():
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface:
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid):
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
remembers the grid.
"""
def launcher(*args, **kwargs):
return self.run(*args, grid=grid, **kwargs)
return launcher
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface):
class JITFunction(KernelInterface[T]):
cache_hook = None
divisibility = 16
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# -----------------------------------------------------------------------------
def jit(*args, **kwargs):
@overload
def jit(fn: T) -> JITFunction[T]:
...
@overload
def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Callable[[T], JITFunction[T]]:
...
def jit(
fn: Optional[T] = None,
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
if args:
assert len(args) == 1
assert callable(args[0])
return JITFunction(args[0], **kwargs)
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
)
if fn is not None:
return decorator(fn)
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
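Both decorator forms keep their runtime behavior, while mypy now sees a `JITFunction` wrapping the original callable. A small usage sketch; the kernels and launch are illustrative and need a CUDA device.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))

# Keyword-only form: jit(...) returns a decorator, mirroring the overloads above.
@triton.jit(do_not_specialize=[1])
def scale_kernel(x_ptr, alpha, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, tl.load(x_ptr + offs) * alpha)

src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
# __getitem__ now returns functools.partial(self.run, grid=grid) rather than a closure.
copy_kernel[(1,)](src, dst, BLOCK=1024)
```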

View File

@@ -34,12 +34,12 @@ def sparsify_tensor(x, mask, block):
return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
if data is None:
data = torch.randn(shape, dtype=torch.float32, device=device)
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
ref_ret = data
ref_ret = ref_ret * alpha + beta
ref_ret = ref_ret.half().float()
ref_ret = ref_ret.half().to(dtype)
if trans:
ref_ret = ref_ret.t().requires_grad_()
ref_ret = ref_ret.detach().requires_grad_()

View File

@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)

View File

@@ -50,7 +50,7 @@ def _fwd_kernel(
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk += tl.dot(q, tl.trans(k))
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
@@ -165,26 +165,26 @@ def _bwd_kernel(
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq)
# # increment pointers
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
@@ -195,6 +195,7 @@ def _bwd_kernel(
tl.store(dk_ptrs, dk)
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
@@ -205,7 +206,7 @@ class _attention(torch.autograd.Function):
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -224,6 +225,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
@@ -268,13 +270,13 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -283,13 +285,16 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
# # triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -299,3 +304,50 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'],
line_names=['Triton'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['fwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
// CHECK-NEXT: Membar 6
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 9
@@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
return
}
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
// So we need a barrier both before and after cst1
// CHECK-LABEL: for_reuse
func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 10
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
// CHECK-LABEL: for_reuse_nested
func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 11
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
}

View File

@@ -387,6 +387,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_fallback
func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
%index = arith.constant 1 : i32
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
return
}
}
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -428,6 +467,100 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_mask
func @basic_insert_slice_async_mask(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
%true = arith.constant 1 : i1
%true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>
// CHECK: llvm.select
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.select
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
return
}
}
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_mask_other
func @basic_insert_slice_async_mask_other(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
%true = arith.constant 1 : i1
%true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>
%other = arith.constant 1.0 : f32
%other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 0 ]
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 16 ]
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor, %other_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
return
}
}
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -712,8 +845,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK-LABEL: convert_layout_mmav2_block
func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -860,6 +1017,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (f32, f32, f32, f32)
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (f32, f32, f32, f32)
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<32x32xf32, #blocked>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
@@ -897,9 +1093,9 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.ntid.x
// CHECK: nvvm.read.ptx.sreg.ntid.y
// CHECK: nvvm.read.ptx.sreg.ntid.z
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32

View File

@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>

View File

@@ -26,7 +26,9 @@ struct TestMembarPass
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
Allocation allocation(operation);
MembarAnalysis analysis(&allocation);
MembarAnalysis membarPass(&allocation);
membarPass.run();
size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {