Compare commits
37 commits: keren/perf...keren/inse

Commits (SHA1):
43408fef5a
e817fdf1b9
8dd099beef
f20f48a255
3eff110fbc
5f85b79718
bab7338965
74f3d7a80f
115cd3ac47
532e10cf87
16e973edf2
b539e031e8
46fa29496c
9490252261
e419781978
189491727a
e0072d210a
2fa17588f7
e057c65cf0
99c7e0e008
f2fcaeabf3
8edfe813a5
4d64589b22
521ff9ad74
c280ebda1b
9def1bcebf
7d90a07d0b
6461254fb5
4e6a8209ed
9bb54402b3
66c36c4378
661be523c0
c87fbf886e
0c1d4d764e
9d31998a9d
04ec5deb41
630dc315ee
.github/workflows/integration-tests.yml (vendored, 11 changed lines)
@@ -17,7 +17,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi

@@ -79,11 +79,18 @@ jobs:
lit -v "$LIT_TEST_DIR"

- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
run: |
cd python/tests
pytest

# TODO[Superjomn] Enable all the tests on V100 if available
- name: Run python tests on V100
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
cd python/tests
pytest test_gemm.py::test_gemm_no_scf_for_mmav1

- name: Run CXX unittests
run: |
cd python/
@@ -20,8 +20,6 @@ SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);

SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);

} // namespace triton

/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

@@ -131,6 +131,12 @@ public:
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;

unsigned getPtrVectorSize(Value ptr);

unsigned getPtrAlignment(Value ptr);

unsigned getMaskAlignment(Value mask);
};

} // namespace mlir
@@ -29,7 +29,11 @@ public:
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}

/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();

private:
struct RegionInfo {

@@ -82,10 +86,6 @@ private:
}
};

/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();

/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
@@ -26,6 +26,12 @@ public:

unsigned getThreadsReductionAxis();

SmallVector<unsigned> getScratchConfigBasic();

SmallVector<SmallVector<unsigned>> getScratchConfigsFast();

unsigned getScratchSizeInBytes();

private:
triton::ReduceOp op;
RankedTensorType srcTy{};

@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);

std::string getValueOperandName(Value value, AsmState &state);

template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}

template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
@@ -295,6 +295,18 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}

def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {

let summary = "transpose a tensor";

let arguments = (ins TT_Tensor:$src);

let results = (outs TT_Tensor:$result);

let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}

//
// SPMD Ops
//

@@ -327,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];

let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);

let results = (outs TT_FpIntTensor:$d);

@@ -351,6 +363,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,

let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}

//
@@ -25,11 +25,13 @@ namespace gpu {

unsigned getElemsPerThread(Type type);

SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);

SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);

SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);

SmallVector<unsigned> getContigPerThread(Attribute layout);

SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
@@ -87,17 +87,25 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);

// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;

// ---- begin version 1 ----
// TODO: handle rep (see
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
if (version == 1) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
return $_get(context, 1, perPhase, maxPhase, order);
}
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}

// ---- begin version 2 ----
if (version == 2) {

@@ -106,14 +114,14 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);

// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
}

// --- handle B operand ---
if (opIdx == 1) {

@@ -121,8 +129,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}

}

llvm_unreachable("invalid operand index");
}
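A minimal standalone sketch of the perPhase computation shown above (rows per swizzle phase = 128 bytes divided by the byte size of one inner-dimension row, clamped to at least 1); the element widths and inner-dimension sizes below are made up for illustration and are not taken from the diff:

#include <algorithm>
#include <cstdio>

// Sketch of the "rows per phase" formula used in the shared-encoding builder.
int perPhase(int innerDimSize, int elementBits) {
  int p = 128 / (innerDimSize * (elementBits / 8));
  return std::max(p, 1);
}

int main() {
  std::printf("%d\n", perPhase(64, 16)); // fp16, inner dim 64 -> 1
  std::printf("%d\n", perPhase(32, 16)); // fp16, inner dim 32 -> 2
  std::printf("%d\n", perPhase(16, 32)); // fp32, inner dim 16 -> 2
}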
@@ -293,7 +301,7 @@ partitioned between warps.
// -------------------------------- version = 1 --------------------------- //

For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Information about this layout can be found in the official PTX documentation
Note: the layout is different from the recommended in PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).

@@ -301,29 +309,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:

warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
[ ..............................................................
[ ..............................................................
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]

warp 1 = warp0 + 32
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
[ ..............................................................
[ ..............................................................
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
[ ............................................................... ]

// -------------------------------- version = 2 --------------------------- //
@@ -151,6 +151,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
// attr-dict `:` type($src) `->` type($dst)
//}];

let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
}];

// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];

@@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
if (isa<tensor::ExtractSliceOp>(op)) {
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
@@ -13,6 +13,7 @@

using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;

@@ -60,8 +61,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;

@@ -88,25 +89,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return paddedRepShape;
}

SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
ReduceOpHelper helper(op);

SmallVector<unsigned> smemShape;
auto srcShape = helper.getSrcShape();
for (auto d : srcShape)
smemShape.push_back(d);

auto axis = op.axis();
if (helper.isFastReduction()) {
smemShape[axis] = helper.getInterWarpSize();
} else {
smemShape[axis] =
std::min(smemShape[axis], helper.getThreadsReductionAxis());
}

return smemShape;
}

// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;

@@ -173,21 +155,9 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
// TODO(Keren): Reduce with index is not supported yet.
auto value = op->getOperand(0);
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
auto smemShape = getScratchConfigForReduce(reduceOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
if (fastReduce) {
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
elems = std::max<unsigned>(elems, numWarps * 32);
}
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
@@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
}
}
// TODO: refactor & complete binary ops
// Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {

@@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };

@@ -261,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation(
return result;
}

unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);

return vec;
}

unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto axisInfo = lookupLatticeElement(ptr)->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
unsigned maxContig = axisInfo.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}

unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskAxis = lookupLatticeElement(mask)->getValue();
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
return alignment;
}

} // namespace mlir
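The remainder rules added above combine the operands' per-dimension lattice values with gcd. Below is a minimal standalone sketch of that transfer for a single dimension; the concrete lattice values are made up for illustration:

#include <cstdio>
#include <numeric>

// Per-dimension AxisInfo values; the transfer mirrors the lambdas above.
struct Axis1D { unsigned contiguity, divisibility, constancy; };

Axis1D remainderTransfer(Axis1D lhs, Axis1D rhs) {
  return {std::gcd(lhs.contiguity, rhs.divisibility),   // new contiguity
          std::gcd(lhs.divisibility, rhs.divisibility), // new divisibility
          std::gcd(lhs.constancy, rhs.constancy)};      // new constancy
}

int main() {
  Axis1D a{/*contiguity=*/8, /*divisibility=*/16, /*constancy=*/1};
  Axis1D b{/*contiguity=*/1, /*divisibility=*/4, /*constancy=*/4};
  Axis1D r = remainderTransfer(a, b);
  std::printf("%u %u %u\n", r.contiguity, r.divisibility, r.constancy); // 4 4 1
}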
@@ -24,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation,
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
}
curRegionInfo.join(regionInfo);
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};

traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may shared the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
}
}

@@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// Do not insert barriers before control flow operations and
// alloc/extract/insert
// alloc is an allocation op without memory write.
// In contrast, arith.constant is an allocation op with memory write.
// FIXME(Keren): extract is always alias for now
// FIXME(Keren): extract_slice is always alias for now
return;
}

@@ -60,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
return;
}

if (isa<triton::gpu::AsyncWaitOp>(op)) {
// If the current op is an async wait, we insert a barrier op and sync
// previous reads and writes.
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
regionInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
@@ -37,6 +37,50 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}

SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);

/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();

/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);

return smemShapes;
}

unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}

auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;

if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);

return bytes;
}

bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {

@@ -61,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) {
}

bool maybeAliasOp(Operation *op) {
return isa<tensor::ExtractSliceOp>(op) ||
return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}
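To make the scratch sizing above concrete, here is a rough standalone sketch of the fast-reduction path; the tensor shape, inter-warp size, and warp count below are invented purely for illustration:

#include <algorithm>
#include <cstdio>

// Assumed example: a 32x128 fp32 tensor reduced along axis 1, 4 warps per CTA.
int main() {
  unsigned smemShape0[2] = {32, 4}; // source shape with axis 1 shrunk to an
                                    // assumed inter-warp size of 4
  unsigned elems0 = smemShape0[0] * smemShape0[1]; // 128
  unsigned elems1 = 4 * 32;                        // numWarps * 32 = 128
  unsigned elems = std::max(elems0, elems1);       // max over the two blocks
  unsigned bytes = elems * 32 / 8;                 // fp32 -> 512 bytes
  // withIndex reductions would add elems * sizeof(int32_t) on top of this.
  std::printf("%u bytes\n", bytes);
}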
lib/Conversion/TritonGPUToLLVM/DotHelpers.h (1816 lines, normal file)
File diff suppressed because it is too large.

lib/Conversion/TritonGPUToLLVM/Utility.h (319 lines, normal file)
@@ -0,0 +1,319 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>

// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
#define insert_element(...) \
rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
#define extract_element(...) \
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)

// Creator for constant
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)

#define tid_val() getThreadId(rewriter, loc)

namespace mlir {
namespace LLVM {
using namespace mlir::triton;

Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}

Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}

SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
}
return results;
}

// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
}

Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}

Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
auto type = type::f64Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF64FloatAttr(v));
}

// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}

// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}

/// Helper function to get strides from a given shape and its order
SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
Location loc, ConversionPatternRewriter &rewriter) {
auto rank = shape.size();
SmallVector<Value> strides(rank);
auto stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
}
return strides;
}

struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at artibary offsets.
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a to be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations but not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
// objects from the originally allocated object.

SharedMemoryObject(Value base, ArrayRef<Value> strides,
ArrayRef<Value> offsets)
: base(base), strides(strides.begin(), strides.end()),
offsets(offsets.begin(), offsets.end()) {}

SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
ArrayRef<unsigned> order, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);

for (auto idx : order) {
offsets.emplace_back(i32_val(0));
}
}

SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
elems.append(offsets.begin(), offsets.end());
return elems;
}

SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
return types;
}

Value getCSwizzleOffset(int order) const {
assert(order >= 0 && order < strides.size());
return offsets[order];
}

Value getBaseBeforeSwizzle(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, base, offset);
}
};

SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}

Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");

PTXBuilder builder;
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
}

Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i) {
unsigned bits = val.getType().getIntOrFloatBitWidth();

if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(loc, rewriter, val0, i);
val1 = shflSync(loc, rewriter, val1, i);
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(vec, val.getType());
}

PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
auto *dOpr = builder.newOperand("=r");
auto *aOpr = builder.newOperand(val, "r");
auto *bOpr = builder.newConstantOperand(i);
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
return builder.launch(rewriter, loc, val.getType(), false);
}

} // namespace LLVM
} // namespace mlir

#endif
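As a quick illustration of the stride computation in getStridesFromShapeAndOrder above, here is a standalone sketch using plain integers in place of MLIR Values; the shape and order below are example inputs only:

#include <cstdio>
#include <vector>

// Walk dimensions from the fastest-varying one (order[0]) outward,
// accumulating the running stride, just like the helper above.
std::vector<long> stridesFromShapeAndOrder(const std::vector<long> &shape,
                                           const std::vector<unsigned> &order) {
  std::vector<long> strides(shape.size());
  long stride = 1;
  for (unsigned idx : order) {
    strides[idx] = stride;
    stride *= shape[idx];
  }
  return strides;
}

int main() {
  // A 16x32 tile with order {1, 0}: dimension 1 is contiguous.
  auto s = stridesFromShapeAndOrder({16, 32}, {1, 0});
  std::printf("%ld %ld\n", s[0], s[1]); // prints "32 1"
}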
@@ -245,9 +245,53 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
adaptor.allowTF32());
return success();
}
};

struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {

using OpConversionPattern<triton::TransOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
Attribute srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
// TODO: end-to-end correctness is broken if
// the input is blocked and the output is shared
// with different order. Maybe a backend issue in BlockedToShared?
SmallVector<unsigned> order = {1, 0};
if (auto srcBlockedEncoding =
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
srcEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
srcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), srcEncoding);
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
src);
}
auto srcSharedEncoding =
srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
srcSharedEncoding.getOrder().end());
SmallVector<int64_t> retShapes(srcType.getShape().begin(),
srcType.getShape().end());
std::reverse(retOrder.begin(), retOrder.end());
std::reverse(retShapes.begin(), retShapes.end());
auto retEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
auto retType =
RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);

rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
return success();
}
};
@@ -286,8 +330,8 @@ struct TritonAtomicCASPattern
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()),
adaptor.ptr(), adaptor.cmp(), adaptor.val());
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.cmp(), adaptor.val());
return success();
}
};
@@ -390,9 +434,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
typeConverter, context);
}

//
@@ -456,10 +501,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};

// This is borrowed from ConvertFIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
// We need to implement something more sophisticated here that tracks which
// types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
// another to 2 types, so newResultTypes would at least be the right size to
// not crash in the llvm::zip call below, but then we would set the the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type> newResultTypes;
for (auto type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}

// See comments in the ForOp pattern for why we clone without regions and
// then inline.
scf::IfOp newOp =
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());

// Update the operands and types.
newOp->setOperands(adaptor.getOperands());
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};

void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
context);
}

class ConvertTritonToTritonGPU
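TritonTransPattern above builds the result type by reversing both the tensor shape and the shared-encoding order of the source. A minimal standalone sketch of that bookkeeping, with example shape and order values, is shown below:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Source: a 32x64 tensor whose shared encoding has order {1, 0}.
  std::vector<long> shape = {32, 64};
  std::vector<unsigned> order = {1, 0};

  // tt.trans result: shape and order are both reversed, mirroring the
  // std::reverse calls in TritonTransPattern.
  std::reverse(shape.begin(), shape.end());
  std::reverse(order.begin(), order.end());

  std::printf("shape = %ldx%ld, order = {%u, %u}\n",
              shape[0], shape[1], order[0], order[1]); // 64x32, {0, 1}
}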
@@ -240,12 +240,17 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp =
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
inferredReturnTypes.push_back(argEltTy);
inferredReturnTypes.push_back(retEltTy);
} else {
// nd-tensor where n >= 1
// infer encoding
@@ -264,11 +269,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
}
// create type
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return mlir::success();
}

bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
}

//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
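The changes above make the ARG* reductions return i32 index values while every other reduction keeps the operand element type. A tiny standalone sketch of that selection (the enum is abbreviated and purely illustrative, not the real RedOp definition):

#include <cstdio>

enum class RedOp { ADD, MAX, MIN, ARGMIN, ARGMAX };

// Mirrors ReduceOp::withIndex: only the ARG* reductions carry an index result.
bool withIndex(RedOp op) { return op == RedOp::ARGMIN || op == RedOp::ARGMAX; }

const char *returnEltTy(RedOp op, const char *argEltTy) {
  return withIndex(op) ? "i32" : argEltTy; // index reductions return i32
}

int main() {
  std::printf("%s\n", returnEltTy(RedOp::MAX, "f16"));    // f16
  std::printf("%s\n", returnEltTy(RedOp::ARGMAX, "f16")); // i32
}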
@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;

def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
@@ -71,22 +71,22 @@ unsigned getElemsPerThread(Type type) {
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
}

SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return SmallVector<unsigned>{4, 8};
return {4, 8};
if (mmaLayout.getVersion() == 2)
return SmallVector<unsigned>{8, 4};
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}

SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
blockedLayout.getWarpsPerCTA().end());
@@ -99,16 +99,22 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return {};
}

SmallVector<unsigned> getSizePerThread(Attribute layout) {
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return SmallVector<unsigned>{2, 2};
if (mmaLayout.getVersion() == 2) {
return {2, 2};
} else if (mmaLayout.getVersion() == 1) {
// Note: here the definition of sizePerThread is obscure, which doesn't
// mean vecSize=4 can be supported in the last dimension.
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -136,6 +142,15 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}

SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
return {1, 2};
} else {
return getSizePerThread(layout);
}
}

SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -194,6 +209,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.getVersion() == 1) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
@@ -205,9 +230,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
@@ -358,6 +383,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
}

return res;
@@ -632,6 +659,15 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}

DenseSet<unsigned>
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}

//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
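Reading the mma cases added to getShapePerCTA above: a version-2 layout covers a 16x8 tile per warp, so the per-CTA shape scales with warpsPerCTA. A quick standalone check with an assumed warpsPerCTA of {2, 2}:

#include <cstdio>

int main() {
  // Assumed mma v2 layout with warpsPerCTA = {2, 2}.
  unsigned warpsPerCTA[2] = {2, 2};
  unsigned shapePerCTA[2] = {16 * warpsPerCTA[0], 8 * warpsPerCTA[1]};
  std::printf("%u x %u\n", shapePerCTA[0], shapePerCTA[1]); // 32 x 16
}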
@@ -50,10 +50,25 @@ public:
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto tmpType =
|
||||
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), 1, 1, 1, {1, 0}));
|
||||
auto dstDotOperand =
|
||||
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto dstParent = dstDotOperand.getParent();
|
||||
if (dstDotOperand.getOpIdx() == 1 ||
|
||||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return mlir::failure();
|
||||
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (dstParentMma.getVersion() == 1 ||
|
||||
dstParentMma.getWarpsPerCTA()[1] > 1)
|
||||
return mlir::failure();
|
||||
SetVector<Operation *> bwdSlices;
|
||||
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
|
||||
if (llvm::find_if(bwdSlices, [](Operation *op) {
|
||||
return isa<triton::DotOp>(op);
|
||||
}) == bwdSlices.end())
|
||||
return mlir::failure();
|
||||
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), dstParentMma);
|
||||
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
convert.getLoc(), tmpType, convert.getOperand());
|
||||
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -81,8 +96,11 @@ public:
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristic to accommodate fused attention
|
||||
// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
// return mlir::failure();
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
|
||||
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return mlir::failure();
|
||||
// convert to the same layout -- we can delete
|
||||
if (op->getResultTypes() == op->getOperandTypes()) {
|
||||
rewriter.replaceOp(op, op->getOperands());
|
||||
@@ -160,6 +178,10 @@ public:
|
||||
!isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
if (isSharedEncoding(convert.getOperand()) &&
|
||||
isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto srcShared =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
@@ -586,13 +608,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
}
|
||||
}
|
||||
|
||||
template <int version>
|
||||
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
|
||||
int numWarps);
|
||||
|
||||
template <>
|
||||
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(1, shape, numWarps);
|
||||
@@ -611,17 +629,25 @@ SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(dotOp.getResult(), &slices);
|
||||
if (llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::DotOp>(op);
|
||||
}) != slices.end())
|
||||
return {(unsigned)numWarps, 1};
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(2, shape, numWarps);
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
bool changed = false;
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
// seems buggy for shape = [32, 16] ?
|
||||
do {
|
||||
changed = false;
|
||||
if (ret[0] * ret[1] >= numWarps)
|
||||
break;
|
||||
if (shape[0] / shapePerWarp[0] / ret[0] >=
|
||||
@@ -638,6 +664,55 @@ SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
|
||||
}
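warpsPerTileV2 above adds an early exit: if the dot's forward slice contains another DotOp (e.g. a fused-attention chain), all warps stay on the first dimension instead of going through the greedy 16x8 split. A hedged Python restatement of just that early exit:

# Sketch of the chained-dot special case only; the greedy loop that follows
# it in the C++ code is unchanged and not reproduced here.
def warps_per_tile_v2_hint(feeds_another_dot: bool, num_warps: int):
    if feeds_another_dot:
        return [num_warps, 1]  # mirrors `return {(unsigned)numWarps, 1};`
    return None                # fall through to the greedy shapePerWarp = [16, 8] loop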
|
||||
} // namespace
|
||||
|
||||
class OptimizeBlockedToShared : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto dstSharedLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!srcBlockedLayout || !dstSharedLayout)
|
||||
return failure();
|
||||
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
|
||||
return failure();
|
||||
// For now only works if single use is transpose
|
||||
// TODO: rematerialize #shared uses
|
||||
auto users = op->getUsers();
|
||||
if (std::distance(users.begin(), users.end()) != 1 ||
|
||||
!isa<triton::TransOp>(*users.begin()))
|
||||
return failure();
|
||||
|
||||
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), dstSharedLayout.getVec(),
|
||||
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
|
||||
srcBlockedLayout.getOrder());
|
||||
auto tmpType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), tmpShared);
|
||||
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), tmpType, cvt.getOperand());
|
||||
|
||||
auto newDstType = RankedTensorType::get(
|
||||
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
|
||||
srcType.getElementType(), dstSharedLayout);
|
||||
|
||||
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
|
||||
tmpCvt.getResult());
|
||||
|
||||
rewriter.replaceOp(*users.begin(), newTrans.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -646,13 +721,14 @@ public:
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
|
||||
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTile<1>(shape, numWarps);
|
||||
return warpsPerTileV1(dotOp, shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTile<2>(shape, numWarps);
|
||||
return warpsPerTileV2(dotOp, shape, numWarps);
|
||||
default:
|
||||
assert(false && "not supported version");
|
||||
return {0, 0};
|
||||
@@ -680,11 +756,12 @@ public:
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||
|
||||
auto newRetType = RankedTensorType::get(
|
||||
retShape, oldRetType.getElementType(),
|
||||
triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), version,
|
||||
getWarpsPerTile(retShape, version, numWarps)));
|
||||
getWarpsPerTile(dotOp, retShape, version, numWarps)));
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -704,8 +781,7 @@ public:
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
|
||||
dotOp.transA(), dotOp.transB());
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
@@ -731,8 +807,9 @@ public:
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<OptimizeBlockedToShared>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
// patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
patterns.add<RematerializeForward>(context);
|
||||
patterns.add<MoveConvertOutOfLoop>(context);
|
||||
|
@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
|
||||
tensorType.getEncoding());
|
||||
}
|
||||
|
||||
#define int_attr(num) builder.getI64IntegerAttr(num)
|
||||
|
||||
namespace {
|
||||
|
||||
class LoopPipeliner {
|
||||
/// cache forOp we are working on
|
||||
/// Cache forOp we are working on
|
||||
scf::ForOp forOp;
|
||||
|
||||
/// cache YieldOp for this forOp
|
||||
/// Cache YieldOp for this forOp
|
||||
scf::YieldOp yieldOp;
|
||||
|
||||
/// loads to be pipelined
|
||||
/// Loads to be pipelined
|
||||
SetVector<Value> loads;
|
||||
/// the value that each load will be mapped to (after layout conversion)
|
||||
/// The value that each load will be mapped to (after layout conversion)
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
/// load => buffer
|
||||
DenseMap<Value, Value> loadsBuffer;
|
||||
@@ -51,7 +53,7 @@ class LoopPipeliner {
|
||||
///
|
||||
Value loopIterIdx;
|
||||
|
||||
/// comments on numStages:
|
||||
/// Comments on numStages:
|
||||
/// [0, numStages-1) are in the prologue
|
||||
/// numStages-1 is appended after the loop body
|
||||
int numStages;
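The staging scheme in the comment above (loads for stages [0, numStages-1) peeled into a prologue, stage numStages-1 issued inside the loop body) can be pictured with a small, hedged Python schematic; it illustrates the schedule, not the pass itself:

# Assumes num_stages >= 2; issue_load/compute stand in for the real async
# copy and the loop body.
def pipelined_loop(num_iters, num_stages, issue_load, compute):
    inflight = []
    for s in range(min(num_stages - 1, num_iters)):  # prologue: stages [0, numStages-1)
        inflight.append(issue_load(s))
    for i in range(num_iters):
        data = inflight.pop(0)                       # wait for the oldest outstanding load
        nxt = i + num_stages - 1
        if nxt < num_iters:                          # stage numStages-1, issued in the body
            inflight.append(issue_load(nxt))
        compute(i, data)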
@@ -61,6 +63,7 @@ class LoopPipeliner {
|
||||
|
||||
/// Block arguments that loads depend on
|
||||
DenseSet<BlockArgument> depArgs;
|
||||
|
||||
/// Operations (inside the loop body) that loads depend on
|
||||
DenseSet<Operation *> depOps;
|
||||
|
||||
@@ -71,7 +74,7 @@ class LoopPipeliner {
|
||||
|
||||
Value lookupOrDefault(Value origin, int stage);
|
||||
|
||||
/// returns a empty buffer of size <numStages, ...>
|
||||
/// Returns an empty buffer of size <numStages, ...>
|
||||
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
|
||||
|
||||
public:
|
||||
@@ -84,7 +87,7 @@ public:
|
||||
/// Collect loads to pipeline. Return success if we can pipeline this loop
|
||||
LogicalResult initialize();
|
||||
|
||||
/// emit pipelined loads (before loop body)
|
||||
/// Emit pipelined loads (before loop body)
|
||||
void emitPrologue();
|
||||
|
||||
/// emit pipelined loads (after loop body)
|
||||
@@ -120,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
return;
|
||||
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
deps.insert(v);
|
||||
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
|
||||
if (arg.getArgNumber() > 0) {
|
||||
// Skip the first arg (loop induction variable)
|
||||
// Otherwise the op idx is arg.getArgNumber()-1
|
||||
deps.insert(v);
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
|
||||
deps);
|
||||
}
|
||||
} else { // value
|
||||
// v might be in deps, but we still need to visit v.
|
||||
// This is because v might depend on value in previous iterations
|
||||
@@ -134,7 +141,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
|
||||
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
OpBuilder &builder) {
|
||||
// allocate a buffer for each pipelined tensor
|
||||
// Allocate a buffer for each pipelined tensor
|
||||
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
|
||||
Value convertLayout = loadsMapping[op->getResult(0)];
|
||||
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
|
||||
@@ -215,9 +222,9 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
// we have some loads to pipeline
|
||||
// We have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// update depArgs & depOps
|
||||
// Update depArgs & depOps
|
||||
for (Value loadOp : loads) {
|
||||
for (Value dep : loadDeps[loadOp]) {
|
||||
// TODO: we should record the stage that the value is depended on
|
||||
@@ -244,23 +251,20 @@ void LoopPipeliner::emitPrologue() {
|
||||
setValueMapping(arg, operand.get(), 0);
|
||||
}
|
||||
|
||||
// helper to construct int attribute
|
||||
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
|
||||
|
||||
// prologue from [0, numStage-1)
|
||||
Value iv = forOp.getLowerBound();
|
||||
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||
// special handling for induction variable as the increment is implicit
|
||||
// Special handling for induction variable as the increment is implicit
|
||||
if (stage != 0)
|
||||
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
|
||||
setValueMapping(forOp.getInductionVar(), iv, stage);
|
||||
|
||||
// special handling for loop condition as there is no condition in ForOp
|
||||
// Special handling for loop condition as there is no condition in ForOp
|
||||
Value loopCond = builder.create<arith::CmpIOp>(
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
|
||||
// rematerialize peeled values
|
||||
// Rematerialize peeled values
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
@@ -314,7 +318,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
}
|
||||
}
|
||||
|
||||
// update mapping of results
|
||||
// Update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
Value originalResult = op->getResult(dstIdx);
|
||||
// copy_async will update the value of its only use
|
||||
@@ -350,13 +354,14 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
||||
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
|
||||
SmallVector<OpFoldResult>{int_attr(1),
|
||||
int_attr(sliceType.getShape()[0]),
|
||||
int_attr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
||||
loadsExtract[loadOp] = extractSlice;
|
||||
}
|
||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// Bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// *next* iteration
|
||||
loopIterIdx = builder.create<arith::AddIOp>(
|
||||
loopIterIdx.getLoc(), loopIterIdx,
|
||||
@@ -365,9 +370,6 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
void LoopPipeliner::emitEpilogue() {
|
||||
// If there are any outstanding async copies, we need to wait for them.
|
||||
// TODO(Keren): We may want to completely avoid the async copies in the last
|
||||
// few iterations by setting is_masked attribute to true. We don't want to use
|
||||
// the mask operand because it's a tensor but not a scalar.
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
@@ -376,14 +378,13 @@ void LoopPipeliner::emitEpilogue() {
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
OpBuilder builder(forOp);
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
// order of new args:
|
||||
// (original args),
|
||||
// (insertSliceAsync buffer at stage numStages - 1) for each load
|
||||
// (extracted tensor) for each load
|
||||
// (depArgs at stage numStages-1)
|
||||
// (iv at stage numStages-1)
|
||||
// Order of new args:
|
||||
// (original args)
|
||||
// (insertSliceAsync buffer at stage numStages - 1) for each load
|
||||
// (extracted tensor) for each load
|
||||
// (depArgs at stage numStages - 1)
|
||||
// (iv at stage numStages - 2)
|
||||
// (pipeline iteration index)
|
||||
// (loop iteration index)
|
||||
SmallVector<Value> newLoopArgs;
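Given the ordering spelled out above, the two index boundaries used further below (depArgsBeginIdx and nextIVIdx) follow directly; a hedged helper, assuming A original args, L pipelined loads and D dep args:

# Illustration only; the index names match the C++ below, the helper itself is hypothetical.
def new_loop_arg_indices(num_orig_args, num_loads, num_dep_args):
    dep_args_begin_idx = num_orig_args + 2 * num_loads  # after buffers + extracted tensors
    next_iv_idx = dep_args_begin_idx + num_dep_args     # then iv and the two iteration indices
    return dep_args_begin_idx, next_iv_idx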
@@ -424,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
||||
|
||||
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
||||
// Insert async wait if necessary.
|
||||
@@ -465,15 +467,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
++argIdx;
|
||||
}
|
||||
// special handling for iv & loop condition
|
||||
// Special handling for iv & loop condition
|
||||
Value nextIV = builder.create<arith::AddIOp>(
|
||||
newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
||||
Value nextLoopCond =
|
||||
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
nextMapping.map(forOp.getInductionVar(), nextIV);
|
||||
|
||||
// slice index
|
||||
// Slice index
|
||||
SmallVector<Value> nextBuffers;
|
||||
SmallVector<Value> extractSlices;
|
||||
|
||||
@@ -490,7 +493,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// update loading mask
|
||||
// Update loading mask
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
@@ -500,7 +503,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// If mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
@@ -522,18 +525,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
nextOp = builder.create<tensor::ExtractSliceOp>(
|
||||
op->getLoc(), sliceType, insertAsyncOp,
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1),
|
||||
intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
|
||||
int_attr(0)},
|
||||
SmallVector<OpFoldResult>{int_attr(1),
|
||||
int_attr(sliceType.getShape()[0]),
|
||||
int_attr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
||||
extractSlices.push_back(nextOp->getResult(0));
|
||||
} else
|
||||
nextOp = builder.clone(*op, nextMapping);
|
||||
// update mapping of results
|
||||
// Update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
||||
// if this is a loop-carried value, update the mapping for yield
|
||||
// If this is a loop-carried value, update the mapping for yield
|
||||
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
for (OpOperand &operand : originYield->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
@@ -583,7 +587,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
it->getDefiningOp()->moveAfter(asyncWait);
|
||||
}
|
||||
|
||||
// bump iteration count
|
||||
// Bump iteration count
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
nextIV.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
|
||||
@@ -600,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (Value nextSlice : extractSlices)
|
||||
yieldValues.push_back(nextSlice);
|
||||
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
|
||||
yieldValues.push_back(
|
||||
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
|
||||
auto arg = newForOp.getRegionIterArgs()[i];
|
||||
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
|
||||
yieldValues.push_back(depArgsMapping[arg]);
|
||||
}
|
||||
yieldValues.push_back(nextIV);
|
||||
yieldValues.push_back(pipelineIterIdx);
|
||||
yieldValues.push_back(loopIterIdx);
|
||||
|
@@ -131,6 +131,11 @@ LogicalResult Prefetcher::initialize() {
|
||||
if (dotsInFor.empty())
|
||||
return failure();
|
||||
|
||||
// TODO: segfault (original for still has uses)
|
||||
// when used in flash attention that has 2 dots in the loop
|
||||
if (dotsInFor.size() > 1)
|
||||
return failure();
|
||||
|
||||
// returns source of cvt
|
||||
auto getPrefetchSrc = [](Value v) -> Value {
|
||||
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
|
||||
|
@@ -25,8 +25,8 @@ def get_build_type():
|
||||
elif check_env_flag("REL_WITH_DEB_INFO"):
|
||||
return "RelWithDebInfo"
|
||||
else:
|
||||
return "Debug"
|
||||
# TODO(Keren): Restore this before we merge into master
|
||||
return "RelWithDebInfo"
|
||||
# TODO: change to release when stable enough
|
||||
#return "Release"
|
||||
|
||||
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
@@ -115,6 +116,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def(py::init<>())
|
||||
.def("load_triton", [](mlir::MLIRContext &self) {
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
// we load LLVM because the frontend uses LLVM.undef for
|
||||
// some placeholders
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
});
|
||||
// .def(py::init([](){
|
||||
// mlir::MLIRContext context;
|
||||
@@ -187,6 +192,7 @@ void init_triton_ir(py::module &&m) {
|
||||
/* issue a warning */
|
||||
}
|
||||
})
|
||||
.def("get_context", &mlir::Value::getContext)
|
||||
.def("replace_all_uses_with",
|
||||
[](mlir::Value &self, mlir::Value &newValue) {
|
||||
self.replaceAllUsesWith(newValue);
|
||||
@@ -335,10 +341,21 @@ void init_triton_ir(py::module &&m) {
|
||||
return funcs[0];
|
||||
});
|
||||
|
||||
m.def("make_attr",
|
||||
[](const std::vector<int> &values, mlir::MLIRContext &context) {
|
||||
return mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(
|
||||
{static_cast<int64_t>(values.size())},
|
||||
mlir::IntegerType::get(&context, 32)),
|
||||
values)
|
||||
.cast<mlir::Attribute>();
|
||||
});
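make_attr packs a Python list of ints into a 1-D dense i32 attribute, presumably for small integer vectors such as an order. A heavily hedged sketch of the Python side; the exact submodule path and context handle are assumptions, not taken from this diff:

import triton._C.libtriton.triton as _triton  # import style used by the tests below
# attr = _triton.ir.make_attr([1, 0], context)  # 'ir' submodule and 'context' are assumed names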
|
||||
m.def(
|
||||
"parse_mlir_module",
|
||||
[](const std::string &inputFilename, mlir::MLIRContext &context) {
|
||||
// initialize registry
|
||||
// note: we initialize llvm for undef
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect,
|
||||
@@ -1068,6 +1085,16 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::RankedTensorType::get(shape, lhsType.getElementType()),
|
||||
lhs, rhs);
|
||||
})
|
||||
.def("create_trans",
|
||||
[](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto argEltType = argType.getElementType();
|
||||
std::vector<int64_t> retShape = argType.getShape();
|
||||
std::reverse(retShape.begin(), retShape.end());
|
||||
return self.create<mlir::triton::TransOp>(
|
||||
loc, mlir::RankedTensorType::get(retShape, argEltType), arg);
|
||||
})
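create_trans backs the new tl.trans builtin: it reverses the tensor shape and emits a TransOp. With dot's transA/transB flags dropped (see create_dot below), kernels transpose operands explicitly; a hedged kernel sketch (BLOCK is chosen by the caller, and dot still needs tile sizes of at least 16):

import triton
import triton.language as tl

@triton.jit
def dot_at_kernel(X, Y, Z, BLOCK: tl.constexpr):
    off = tl.arange(0, BLOCK)
    x = tl.load(X + off[:, None] * BLOCK + off[None, :])
    y = tl.load(Y + off[:, None] * BLOCK + off[None, :])
    z = tl.dot(tl.trans(x), y)  # previously tl.dot(x, y, trans_a=True)
    tl.store(Z + off[:, None] * BLOCK + off[None, :], z)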
.def("create_broadcast",
|
||||
[](mlir::OpBuilder &self, mlir::Value &arg,
|
||||
std::vector<int64_t> &shape) -> mlir::Value {
|
||||
@@ -1096,7 +1123,8 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
if (auto srcTensorType =
|
||||
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
mlir::Type dstElemType = srcTensorType.getElementType()
|
||||
.cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
@@ -1156,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_dot",
|
||||
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
|
||||
mlir::Value &c, bool allowTF32, bool transA,
|
||||
bool transB) -> mlir::Value {
|
||||
mlir::Value &c, bool allowTF32) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
|
||||
allowTF32, transA, transB);
|
||||
allowTF32);
|
||||
})
|
||||
.def("create_exp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
@@ -1195,10 +1222,11 @@ void init_triton_ir(py::module &&m) {
|
||||
operand.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = inputTensorType.getShape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
mlir::Type resType = inputTensorType.getElementType();
|
||||
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
|
||||
mlir::Type resType = withIndex ? self.getI32Type()
|
||||
: inputTensorType.getElementType();
|
||||
if (!shape.empty()) {
|
||||
resType = mlir::RankedTensorType::get(
|
||||
shape, inputTensorType.getElementType());
|
||||
resType = mlir::RankedTensorType::get(shape, resType);
|
||||
}
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
||||
operand, axis);
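The resType change above makes index-producing reductions (those for which ReduceOp::withIndex is true, e.g. argmin/argmax) return i32 rather than the input element type; a hedged Python-side restatement, matching the get_reduced_dtype change in the tests further below:

# Sketch; 'argmin'/'argmax' being the with-index ops is inferred from the test changes.
def reduce_result_dtype(red_op: str, input_dtype: str) -> str:
    with_index = red_op in ("argmin", "argmax")  # mirrors ReduceOp::withIndex
    return "int32" if with_index else input_dtype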
@@ -1231,6 +1259,12 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::StringAttr::get(self.getContext(),
|
||||
llvm::StringRef(prefix)),
|
||||
values);
|
||||
})
|
||||
// Undef
|
||||
.def("create_undef",
|
||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<::mlir::LLVM::UndefOp>(loc, type);
|
||||
});
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
@@ -1348,6 +1382,11 @@ void init_triton_translation(py::module &m) {
|
||||
llvm::SMDiagnostic error;
|
||||
std::unique_ptr<llvm::Module> module =
|
||||
llvm::parseIR(buffer->getMemBufferRef(), error, context);
|
||||
if (!module)
|
||||
llvm::report_fatal_error(
|
||||
"failed to parse IR: " + error.getMessage() +
|
||||
"lineno: " + std::to_string(error.getLineNo()));
|
||||
|
||||
// translate module to PTX
|
||||
auto ptxCode =
|
||||
triton::translateLLVMIRToPTX(*module, capability, version);
|
||||
|
18
python/tests/libdevice_testutil.py
Normal file
@@ -0,0 +1,18 @@
import os
from typing import Optional

_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
    '/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
    '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]

SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
    if os.path.exists(_p):
        SYSTEM_LIBDEVICE_PATH = _p


def system_libdevice_path() -> str:
    assert SYSTEM_LIBDEVICE_PATH is not None, \
        "Could not find libdevice.10.bc path"
    return SYSTEM_LIBDEVICE_PATH
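The updated tests import this helper instead of hard-coding a libdevice path; a minimal usage sketch:

# Mirrors the import added in the test files below.
from tests.libdevice_testutil import system_libdevice_path

lib_path = system_libdevice_path()  # asserts if no local libdevice.10.bc is installed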
188
python/tests/test_blocksparse.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
# TODO: float32 fails
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@pytest.mark.parametrize("TRANS_B", [False, True])
|
||||
@pytest.mark.parametrize("TRANS_A", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
|
||||
@pytest.mark.parametrize("DTYPE", [torch.float16])
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=384):
|
||||
seed = 0
|
||||
torch.manual_seed(seed)
|
||||
is_sdd = MODE == "sdd"
|
||||
is_dsd = MODE == "dsd"
|
||||
is_dds = MODE == "dds"
|
||||
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
|
||||
# create inputs
|
||||
# create op
|
||||
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
|
||||
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
|
||||
c_shape = (Z, H, M, N)
|
||||
shape = {
|
||||
"sdd": (M, N),
|
||||
"dsd": (a_shape[2], a_shape[3]),
|
||||
"dds": (b_shape[2], b_shape[3]),
|
||||
}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
b_ref = do_mask(b_ref) if is_dds else b_ref
|
||||
a_ref.requires_grad_().retain_grad()
|
||||
b_ref.requires_grad_().retain_grad()
|
||||
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
|
||||
b_ref.transpose(2, 3) if TRANS_B else b_ref)
|
||||
c_ref.backward(dc_ref)
|
||||
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
|
||||
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
|
||||
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
|
||||
# triton result
|
||||
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
|
||||
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
|
||||
b_tri = do_sparsify(b_tri) if is_dds else b_tri
|
||||
a_tri.requires_grad_().retain_grad()
|
||||
b_tri.requires_grad_().retain_grad()
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
|
||||
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
|
||||
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
|
||||
da_tri = a_tri.grad
|
||||
db_tri = b_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(c_ref, c_tri)
|
||||
triton.testing.assert_almost_equal(da_ref, da_tri)
|
||||
triton.testing.assert_almost_equal(db_ref, db_tri)
|
||||
|
||||
|
||||
configs = [
|
||||
(16, 256),
|
||||
(32, 576),
|
||||
(64, 1871),
|
||||
(128, 2511),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_dense", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
|
||||
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
Z, H, M, N = 2, 3, WIDTH, WIDTH
|
||||
# initialize layout
|
||||
# make sure each row has at least one non-zero element
|
||||
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
|
||||
if is_dense:
|
||||
layout[:] = 1
|
||||
else:
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# initialize data
|
||||
a_shape = (Z, H, M, N)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape)
|
||||
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
|
||||
# compute [torch]
|
||||
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref.retain_grad()
|
||||
at_mask = torch.ones((M, N), device="cuda")
|
||||
if is_causal:
|
||||
at_mask = torch.tril(at_mask)
|
||||
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
|
||||
a_ref[M == 0] = float("-inf")
|
||||
out_ref = torch.softmax(a_ref * scale, -1)
|
||||
out_ref.backward(dout_ref)
|
||||
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
# compute [triton]
|
||||
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri.retain_grad()
|
||||
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
|
||||
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
|
||||
out_tri.backward(dout_tri)
|
||||
da_tri = a_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(out_tri, out_ref)
|
||||
triton.testing.assert_almost_equal(da_tri, da_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_attention_fwd_bwd(
|
||||
block,
|
||||
dtype,
|
||||
input_scale=1.0,
|
||||
scale=1 / 8.0,
|
||||
n_ctx=256,
|
||||
batch_size=2,
|
||||
n_heads=2,
|
||||
):
|
||||
# inputs
|
||||
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
||||
qkvs = [
|
||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
|
||||
]
|
||||
|
||||
# Triton:
|
||||
n_blocks = n_ctx // block
|
||||
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
|
||||
query, key, value = [x.clone() for x in qkvs]
|
||||
query.retain_grad()
|
||||
key.retain_grad()
|
||||
value.retain_grad()
|
||||
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
|
||||
# ad hoc loss
|
||||
loss = (attn_out ** 2).mean()
|
||||
loss.backward()
|
||||
grads = [query.grad, key.grad, value.grad]
|
||||
|
||||
# Torch version:
|
||||
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
|
||||
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
|
||||
attn_mask = torch.tril(attn_mask, diagonal=0)
|
||||
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
|
||||
torch_q.retain_grad()
|
||||
torch_k.retain_grad()
|
||||
torch_v.retain_grad()
|
||||
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
|
||||
scores = scores + attn_mask
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
|
||||
# ad hoc loss
|
||||
torch_loss = (torch_attn_out ** 2).mean()
|
||||
torch_loss.backward()
|
||||
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
|
||||
|
||||
# comparison
|
||||
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
triton.testing.assert_almost_equal(loss, torch_loss)
|
||||
for g1, g2 in zip(grads, torch_grads):
|
||||
triton.testing.assert_almost_equal(g1, g2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
def triton_attention(
|
||||
layout,
|
||||
block: int,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
|
||||
|
||||
w = sparse_dot_sdd_nt(query, key)
|
||||
w = sparse_softmax(w, scale=scale, is_causal=True)
|
||||
a = sparse_dot_dsd_nn(w, value)
|
||||
return a
|
@@ -12,6 +12,7 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||
from tests.libdevice_testutil import system_libdevice_path
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -667,7 +668,6 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
tl.atomic_add(Z + off1, z)
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
|
||||
print(x)
|
||||
# reference result
|
||||
z_ref = np.sum(x, axis=axis, keepdims=False)
|
||||
# triton result
|
||||
@@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
# def test_atomic_cas():
|
||||
# # 1. make sure that atomic_cas changes the original value (Lock)
|
||||
# @triton.jit
|
||||
# def change_value(Lock):
|
||||
# tl.atomic_cas(Lock, 0, 1)
|
||||
|
||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
# change_value[(1,)](Lock)
|
||||
|
||||
# assert (Lock[0] == 1)
|
||||
|
||||
# # 2. only one block enters the critical section
|
||||
# @triton.jit
|
||||
# def serialized_add(data, Lock):
|
||||
# ptrs = data + tl.arange(0, 128)
|
||||
# while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
# pass
|
||||
|
||||
# tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
|
||||
# # release lock
|
||||
# tl.atomic_xchg(Lock, 0)
|
||||
|
||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
# data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
# ref = torch.full((128,), 64.0)
|
||||
# serialized_add[(64,)](data, Lock)
|
||||
# triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
def test_simple_atomic_cas():
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
@@ -717,6 +688,25 @@ def test_simple_atomic_cas():
|
||||
|
||||
assert (Lock[0] == 1)
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
|
||||
# # ---------------
|
||||
# # test cast
|
||||
# # ---------------
|
||||
@@ -1077,122 +1067,126 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# # ---------------
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
# [(epilogue, allow_tf32, dtype)
|
||||
# for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
# for allow_tf32 in [True, False]
|
||||
# for dtype in ['float16']
|
||||
# if not (allow_tf32 and (dtype in ['float16']))])
|
||||
# def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
# cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
# if cc < 80:
|
||||
# if dtype == 'int8':
|
||||
# pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
# elif dtype == 'float32' and allow_tf32:
|
||||
# pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
[(epilogue, allow_tf32, dtype)
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16']
|
||||
if not (allow_tf32 and (dtype in ['float16']))])
|
||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 80:
|
||||
if dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
|
||||
# M, N, K = 128, 128, 64
|
||||
# num_warps = 8
|
||||
# trans_a, trans_b = False, False
|
||||
M, N, K = 64, 64, 64
|
||||
num_warps = 4
|
||||
trans_a, trans_b = False, False
|
||||
|
||||
# # triton kernel
|
||||
# @triton.jit
|
||||
# def kernel(X, stride_xm, stride_xk,
|
||||
# Y, stride_yk, stride_yn,
|
||||
# W, stride_wn, stride_wl,
|
||||
# Z, stride_zm, stride_zn,
|
||||
# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
# ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
# ALLOW_TF32: tl.constexpr,
|
||||
# DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
# TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
|
||||
# off_m = tl.arange(0, BLOCK_M)
|
||||
# off_n = tl.arange(0, BLOCK_N)
|
||||
# off_l = tl.arange(0, BLOCK_N)
|
||||
# off_k = tl.arange(0, BLOCK_K)
|
||||
# Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
||||
# Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
# Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
|
||||
# Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
# z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
|
||||
# if ADD_MATRIX:
|
||||
# z += tl.load(Zs)
|
||||
# if ADD_ROWS:
|
||||
# ZRs = Z + off_m * stride_zm
|
||||
# z += tl.load(ZRs)[:, None]
|
||||
# if ADD_COLS:
|
||||
# ZCs = Z + off_n * stride_zn
|
||||
# z += tl.load(ZCs)[None, :]
|
||||
# if DO_SOFTMAX:
|
||||
# max = tl.max(z, 1)
|
||||
# z = z - max[:, None]
|
||||
# num = tl.exp(z)
|
||||
# den = tl.sum(num, 1)
|
||||
# z = num / den[:, None]
|
||||
# if CHAIN_DOT:
|
||||
# # tl.store(Zs, z)
|
||||
# # tl.debug_barrier()
|
||||
# z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
|
||||
# tl.store(Zs, z)
|
||||
# # input
|
||||
# rs = RandomState(17)
|
||||
# x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
|
||||
# y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
|
||||
# w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
|
||||
# if allow_tf32:
|
||||
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
# y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
# w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
# x_tri = to_triton(x, device=device)
|
||||
# y_tri = to_triton(y, device=device)
|
||||
# w_tri = to_triton(w, device=device)
|
||||
# # triton result
|
||||
# z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
# z_tri = to_triton(z, device=device)
|
||||
# if epilogue == 'trans':
|
||||
# z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
# y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
# w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
# z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
# TRANS_A=trans_a, TRANS_B=trans_b,
|
||||
# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
# ADD_MATRIX=epilogue == 'add-matrix',
|
||||
# ADD_ROWS=epilogue == 'add-rows',
|
||||
# ADD_COLS=epilogue == 'add-cols',
|
||||
# DO_SOFTMAX=epilogue == 'softmax',
|
||||
# CHAIN_DOT=epilogue == 'chain-dot',
|
||||
# ALLOW_TF32=allow_tf32,
|
||||
# num_warps=num_warps)
|
||||
# # torch result
|
||||
# x_ref = x.T if trans_a else x
|
||||
# y_ref = y.T if trans_b else y
|
||||
# z_ref = np.matmul(x_ref, y_ref)
|
||||
# if epilogue == 'add-matrix':
|
||||
# z_ref += z
|
||||
# if epilogue == 'add-rows':
|
||||
# z_ref += z[:, 0][:, None]
|
||||
# if epilogue == 'add-cols':
|
||||
# z_ref += z[0, :][None, :]
|
||||
# if epilogue == 'softmax':
|
||||
# num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
|
||||
# denom = np.sum(num, axis=-1, keepdims=True)
|
||||
# z_ref = num / denom
|
||||
# if epilogue == 'chain-dot':
|
||||
# z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
|
||||
# # compare
|
||||
# # print(z_ref[:,0], z_tri[:,0])
|
||||
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# # make sure ld/st are vectorized
|
||||
# ptx = pgm.asm['ptx']
|
||||
# assert 'ld.global.v4' in ptx
|
||||
# assert 'st.global.v4' in ptx
|
||||
# if allow_tf32:
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
# elif dtype == 'float32':
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
# elif dtype == 'int8':
|
||||
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xk,
|
||||
Y, stride_yk, stride_yn,
|
||||
W, stride_wn, stride_wl,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_l = tl.arange(0, BLOCK_N)
|
||||
off_k = tl.arange(0, BLOCK_K)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
x = tl.load(Xs)
|
||||
y = tl.load(Ys)
|
||||
x = tl.trans(x) if TRANS_A else x
|
||||
y = tl.trans(y) if TRANS_B else y
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
ZRs = Z + off_m * stride_zm
|
||||
z += tl.load(ZRs)[:, None]
|
||||
if ADD_COLS:
|
||||
ZCs = Z + off_n * stride_zn
|
||||
z += tl.load(ZCs)[None, :]
|
||||
if DO_SOFTMAX:
|
||||
max = tl.max(z, 1)
|
||||
z = z - max[:, None]
|
||||
num = tl.exp(z)
|
||||
den = tl.sum(num, 1)
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws))
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
|
||||
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
|
||||
if allow_tf32:
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
# triton result
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
TRANS_A=trans_a, TRANS_B=trans_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# torch result
|
||||
x_ref = x.T if trans_a else x
|
||||
y_ref = y.T if trans_b else y
|
||||
z_ref = np.matmul(x_ref, y_ref)
|
||||
if epilogue == 'add-matrix':
|
||||
z_ref += z
|
||||
if epilogue == 'add-rows':
|
||||
z_ref += z[:, 0][:, None]
|
||||
if epilogue == 'add-cols':
|
||||
z_ref += z[0, :][None, :]
|
||||
if epilogue == 'softmax':
|
||||
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
|
||||
denom = np.sum(num, axis=-1, keepdims=True)
|
||||
z_ref = num / denom
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32':
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
# def test_dot_without_load():
|
||||
@@ -1559,7 +1553,7 @@ def test_num_warps_pow2():
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'libdevice.ffs', ''),
|
||||
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
||||
('float32', 'libdevice.pow', system_libdevice_path()),
|
||||
('float64', 'libdevice.norm4d', '')])
|
||||
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
|
||||
|
@@ -5,6 +5,7 @@ import _testcapi
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
from tests.libdevice_testutil import system_libdevice_path
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -32,8 +33,6 @@ torch_ops = {
|
||||
"where": "where",
|
||||
}
|
||||
|
||||
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
|
||||
|
||||
|
||||
def get_tensor(shape, data_type, b_positive=False):
|
||||
x = None
|
||||
@@ -90,7 +89,11 @@ def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
y_ref = getattr(torch, torch_ops[expr])(x)
|
||||
# compare
|
||||
@@ -134,7 +137,11 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
|
||||
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x0, x1, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
|
||||
if expr == "cdiv":
|
||||
@@ -182,7 +189,11 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
|
||||
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x0, x1, x2, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
|
||||
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
|
||||
|
@@ -5,6 +5,7 @@ from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from tests.libdevice_testutil import system_libdevice_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
|
||||
@@ -125,7 +126,7 @@ def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
||||
[('int32', 'libdevice.ffs', system_libdevice_path()),
|
||||
('int32', 'libdevice.ffs', '')])
|
||||
def test_libdevice(dtype_str, expr, lib_path):
|
||||
src = f"""
|
||||
|
@@ -172,8 +172,9 @@ def get_proper_err(a, b, golden):
|
||||
[128, 64, 128, 4, 128, 64, 128, False, False],
|
||||
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
|
||||
# K-Forloop
|
||||
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
|
||||
# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
|
||||
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
|
||||
[64, 32, 128, 4, 64, 32, 64, False, False],
|
||||
[128, 16, 128, 4, 128, 16, 32, False, False],
|
||||
[32, 16, 128, 4, 32, 16, 32, False, False],
|
||||
@@ -187,7 +188,8 @@ def get_proper_err(a, b, golden):
|
||||
[128, 256, 128, 4, 128, 256, 32, False, False],
|
||||
[256, 128, 64, 4, 256, 128, 16, False, False],
|
||||
[128, 64, 128, 4, 128, 64, 32, False, False],
|
||||
# [16, 16, 64, 4, 16, 16, 16, False, False], # TODO failed due to pipeline pass
|
||||
[16, 16, 64, 4, 16, 16, 16, False, False],
|
||||
[32, 32, 64, 4, 32, 32, 32, False, False],
|
||||
# trans
|
||||
[128, 64, 128, 4, 128, 64, 32, True, False],
|
||||
[128, 64, 128, 4, 128, 64, 32, False, True],
|
||||
@@ -218,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
|
||||
[32, 32, 16, 4, 32, 32, 16],
|
||||
[32, 16, 16, 4, 32, 32, 16],
|
||||
[128, 8, 8, 4, 32, 32, 16],
|
||||
# TODO[Superjomn]: fix it later
|
||||
# [127, 41, 43, 4, 32, 32, 16],
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
|
||||
[32, 32, 16, 4, 32, 32, 16, False],
|
||||
[32, 32, 16, 4, 32, 32, 16, True],
|
||||
[32, 16, 16, 4, 32, 32, 16, False],
|
||||
[32, 16, 16, 4, 32, 32, 16, True],
|
||||
[127, 41, 43, 4, 32, 32, 16, False],
|
||||
[127, 41, 43, 4, 32, 32, 16, True],
|
||||
[128, 8, 8, 4, 32, 32, 16, False],
|
||||
[128, 8, 8, 4, 32, 32, 16, True]
|
||||
])
|
||||
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
@@ -234,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
@@ -251,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
|
||||
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
|
||||
a = tl.load(a_ptrs, a_mask)
|
||||
b = tl.load(b_ptrs, b_mask)
|
||||
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
|
||||
accumulator += tl.dot(a, b, allow_tf32=False)
|
||||
a = tl.load(a_ptrs, a_mask, other=0.0)
|
||||
b = tl.load(b_ptrs, b_mask, other=0.0)
|
||||
accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
offs_k += BLOCK_SIZE_K
|
||||
@@ -265,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, c_mask)
|
||||
|
||||
# Configure the pytorch counterpart
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
@@ -275,8 +283,30 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
|
||||
if allow_tf32:
|
||||
# TF32 is not accurate enough
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
|
||||
else:
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
|
||||
|
||||
|
||||
# NOTE this is useful only on Volta GPU.
|
||||
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
|
||||
(shape, num_warps, trans_a, trans_b)
|
||||
for shape in [
|
||||
[16, 16, 16],
|
||||
[16, 16, 32],
|
||||
[32, 16, 16],
|
||||
[32, 32, 32],
|
||||
[128, 16, 16],
|
||||
]
|
||||
for num_warps in [1]
|
||||
for trans_a in [False]
|
||||
for trans_b in [False]
|
||||
])
|
||||
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
|
||||
|
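The fp32 GEMM test above now parametrizes allow_tf32 and flips `torch.backends.cuda.matmul.allow_tf32` so the PyTorch reference is computed with the same precision the kernel uses, loosening the tolerances when TF32 is on. A minimal sketch of that comparison pattern, where `run_triton_matmul` is a hypothetical stand-in for launching the kernel:

import torch

def check_matmul(run_triton_matmul, a, b, allow_tf32):
    # Compute the golden result with the same TF32 setting the kernel uses.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    golden = torch.matmul(a, b)
    c = run_triton_matmul(a, b)
    # TF32 drops fp32 mantissa bits, so the tolerances must be looser.
    tol = 1e-2 if allow_tf32 else 1e-4
    torch.testing.assert_close(c, golden, rtol=tol, atol=tol)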
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
@@ -13,7 +14,9 @@ dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
|
||||
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
|
||||
|
||||
|
||||
def get_reduced_dtype(dtype):
|
||||
def get_reduced_dtype(op, dtype):
|
||||
if op in ['argmin', 'argmax']:
|
||||
return torch.int32
|
||||
if dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||
return torch.int32
|
||||
if dtype in [torch.bfloat16]:
|
||||
@@ -48,7 +51,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
|
||||
|
||||
reduce1d_configs = [
|
||||
(op, dtype, shape)
|
||||
for op in ['sum', 'min', 'max']
|
||||
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
|
||||
for dtype in dtypes
|
||||
for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
|
||||
]
|
||||
@@ -56,8 +59,11 @@ reduce1d_configs = [
|
||||
|
||||
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
|
||||
def test_reduce1d(op, dtype, shape):
|
||||
if op == 'xor_sum' and dtype in float_dtypes:
|
||||
return
|
||||
|
||||
dtype = dtype_mapping[dtype]
|
||||
reduced_dtype = get_reduced_dtype(dtype)
|
||||
reduced_dtype = get_reduced_dtype(op, dtype)
|
||||
|
||||
if dtype.is_floating_point:
|
||||
x = torch.randn((shape,), device='cuda', dtype=dtype)
|
||||
@@ -79,8 +85,17 @@ def test_reduce1d(op, dtype, shape):
|
||||
golden_z = torch.sum(x, dtype=reduced_dtype)
|
||||
elif op == 'min':
|
||||
golden_z = torch.min(x).to(reduced_dtype)
|
||||
else:
|
||||
elif op == 'max':
|
||||
golden_z = torch.max(x).to(reduced_dtype)
|
||||
elif op == 'argmin':
|
||||
golden_z = torch.argmin(x).to(reduced_dtype)
|
||||
elif op == 'argmax':
|
||||
golden_z = torch.argmax(x).to(reduced_dtype)
|
||||
elif op == 'xor_sum':
|
||||
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
|
||||
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
|
||||
else:
|
||||
raise RuntimeError(f'Unknown reduce op {op}')
|
||||
|
||||
if dtype.is_floating_point and op == 'sum':
|
||||
if shape >= 256:
|
||||
@@ -95,7 +110,7 @@ def test_reduce1d(op, dtype, shape):
|
||||
|
||||
reduce2d_configs = [
|
||||
(op, dtype, shape, axis)
|
||||
for op in ['sum', 'min', 'max']
|
||||
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
|
||||
for dtype in dtypes
|
||||
for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
|
||||
for axis in [0, 1]
|
||||
@@ -104,8 +119,11 @@ reduce2d_configs = [
|
||||
|
||||
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
|
||||
def test_reduce2d(op, dtype, shape, axis):
|
||||
if op == 'xor_sum' and dtype in float_dtypes:
|
||||
return
|
||||
|
||||
dtype = dtype_mapping[dtype]
|
||||
reduced_dtype = get_reduced_dtype(dtype)
|
||||
reduced_dtype = get_reduced_dtype(op, dtype)
|
||||
reduced_shape = (shape[1 - axis],)
|
||||
|
||||
if dtype.is_floating_point:
|
||||
@@ -123,8 +141,18 @@ def test_reduce2d(op, dtype, shape, axis):
|
||||
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
|
||||
elif op == 'min':
|
||||
golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
|
||||
else:
|
||||
elif op == 'max':
|
||||
golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
|
||||
elif op == 'argmin':
|
||||
golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
|
||||
elif op == 'argmax':
|
||||
golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
|
||||
elif op == 'xor_sum':
|
||||
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
|
||||
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
|
||||
else:
|
||||
raise RuntimeError(f'Unknown reduce op {op}')
|
||||
|
||||
if dtype.is_floating_point and op == 'sum':
|
||||
if shape[axis] >= 256:
|
||||
assert_close(z, golden_z, rtol=0.05, atol=0.1)
|
||||
|
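For the new `argmin`, `argmax`, and `xor_sum` cases the golden values are computed on the host: `torch.argmin`/`torch.argmax` cast to int32 (the reduced dtype for index ops), and NumPy's `bitwise_xor.reduce` for the xor reduction, which has no direct torch equivalent. A self-contained, CPU-only sketch of those references:

import numpy as np
import torch

x = torch.randint(0, 100, (128,), dtype=torch.int32)

golden_argmax = torch.argmax(x).to(torch.int32)   # index of the maximum element
golden_argmin = torch.argmin(x).to(torch.int32)   # index of the minimum element
golden_xor = torch.tensor(np.bitwise_xor.reduce(x.numpy()))  # XOR-fold of all elements

print(golden_argmax.item(), golden_argmin.item(), golden_xor.item())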
@@ -1,15 +1,52 @@
|
||||
"""isort:skip_file"""
|
||||
# flake8: noqa: F401
|
||||
__version__ = '2.0.0'
|
||||
|
||||
# ---------------------------------------
|
||||
# Note: import order is significant here.
|
||||
|
||||
# TODO: torch needs to be imported first
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
import torch # noqa: F401
|
||||
|
||||
# submodules
|
||||
from .utils import *
|
||||
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
|
||||
from . import impl
|
||||
from .utils import (
|
||||
cdiv,
|
||||
MockTensor,
|
||||
next_power_of_2,
|
||||
reinterpret,
|
||||
TensorWrapper,
|
||||
)
|
||||
from .runtime import (
|
||||
autotune,
|
||||
Config,
|
||||
heuristics,
|
||||
JITFunction,
|
||||
KernelInterface,
|
||||
)
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from . import language
|
||||
from . import testing
|
||||
from . import ops
|
||||
|
||||
__all__ = [
|
||||
"autotune",
|
||||
"cdiv",
|
||||
"CompilationError",
|
||||
"compile",
|
||||
"Config",
|
||||
"heuristics",
|
||||
"impl",
|
||||
"jit",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"language",
|
||||
"MockTensor",
|
||||
"next_power_of_2",
|
||||
"ops",
|
||||
"reinterpret",
|
||||
"runtime",
|
||||
"TensorWrapper",
|
||||
"testing",
|
||||
]
|
||||
|
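Replacing the star imports with explicit imports plus an `__all__` list makes the public surface of the top-level package checkable. A quick sanity check, assuming a Triton build at this revision is installed:

import triton

# Every re-exported name is enumerated and actually present on the module.
assert "jit" in triton.__all__
assert "autotune" in triton.__all__
assert set(triton.__all__) <= set(dir(triton))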
@@ -25,6 +25,8 @@ from filelock import FileLock
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
from . import impl
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
@@ -359,7 +361,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
liveins_copy = liveins.copy()
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -394,7 +396,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
|
||||
# variables defined in the else block but not in the then block:
|
||||
# look them up in the parent scope and yield them
|
||||
for else_name in else_defs:
|
||||
if else_name in liveins and else_name not in then_defs:
|
||||
if else_defs[else_name].type == liveins[else_name].type:
|
||||
names.append(else_name)
|
||||
ret_types.append(else_defs[else_name].type)
|
||||
then_defs[else_name] = liveins_copy[else_name]
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse: # with else block
|
||||
@@ -528,8 +538,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
if len(yields) > 0:
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
@@ -594,11 +603,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
# Create placeholder for the loop induction variable
|
||||
# We can use any value because the variable isn't a constexpr
|
||||
# but use a distinctive value (of the right type) to ease debugging
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
|
||||
self.visit(init_node)
|
||||
iv = self.builder.create_undef(self.builder.get_int32_ty())
|
||||
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
|
||||
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
@@ -711,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core or \
|
||||
isinstance(fn, triton.language.extern.ExternalFunction):
|
||||
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
|
||||
or impl.is_builtin(fn):
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
@@ -757,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
if isinstance(lhs, triton.language.tensor):
|
||||
if node.attr == "T":
|
||||
return triton.language.semantic.trans(lhs, builder=self.builder)
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
@@ -1014,6 +1022,7 @@ def ty_to_cpp(ty):
|
||||
"u32": "uint32_t",
|
||||
"u64": "uint64_t",
|
||||
"fp32": "float",
|
||||
"f32": "float",
|
||||
}[ty]
|
||||
|
||||
|
||||
@@ -1044,6 +1053,7 @@ def generate_launcher(constants, signature):
|
||||
'u32': 'uint32_t',
|
||||
'u64': 'uint64_t',
|
||||
'fp32': 'float',
|
||||
'f32': 'float',
|
||||
'fp64': 'double',
|
||||
}[ty]
|
||||
|
||||
@@ -1343,7 +1353,31 @@ def make_hash(fn, **kwargs):
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
# and any following whitespace
|
||||
# - (public\s+)? : optionally match the keyword public and any following whitespace
|
||||
# - (@\w+) : match an @ symbol followed by one or more word characters
|
||||
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
||||
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
||||
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
||||
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
||||
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
||||
prototype_pattern = {
|
||||
"ttir": mlir_prototype_pattern,
|
||||
"ttgir": mlir_prototype_pattern,
|
||||
"ptx": ptx_prototype_pattern,
|
||||
}
|
||||
|
||||
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
|
||||
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
|
||||
arg_type_pattern = {
|
||||
"ttir": mlir_arg_type_pattern,
|
||||
"ttgir": mlir_arg_type_pattern,
|
||||
"ptx": ptx_arg_type_pattern,
|
||||
}
|
||||
|
||||
|
||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
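The comment block above documents `mlir_prototype_pattern`: group 1 captures the kernel name and group 2 the raw argument list, after which `mlir_arg_type_pattern` extracts one type per argument. This is how `compile` recovers a signature when it is handed a `.ttgir` file rather than a `JITFunction`. A quick check of the two regexes on a made-up TTIR prototype line (the kernel name and types here are illustrative only):

import re

mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'

line = 'func public @kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32) {'

match = re.search(mlir_prototype_pattern, line, re.MULTILINE)
name, arglist = match.group(1), match.group(2)
types = re.findall(mlir_arg_type_pattern, arglist)

print(name)   # @kernel
print(types)  # ['!tt.ptr<f16>', 'i32']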
@@ -1354,6 +1388,27 @@ def compile(fn, **kwargs):
|
||||
context = _triton.ir.context()
|
||||
asm = dict()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
device = kwargs.get("device", torch.cuda.current_device())
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0]*10 + capability[1]
|
||||
# build compilation stages
|
||||
stages = {
|
||||
"ast" : (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, capability))
|
||||
}
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
signature = kwargs["signature"]
|
||||
@@ -1368,13 +1423,17 @@ def compile(fn, **kwargs):
|
||||
kwargs["signature"] = signature
|
||||
else:
|
||||
assert isinstance(fn, str)
|
||||
name, ir = os.path.basename(fn).split(".")
|
||||
assert ir == "ttgir"
|
||||
asm[ir] = _triton.ir.parse_mlir_module(fn, context)
|
||||
function = asm[ir].get_single_function()
|
||||
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
|
||||
_, ir = os.path.basename(fn).split(".")
|
||||
src = Path(fn).read_text()
|
||||
import re
|
||||
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
|
||||
name, signature = match.group(1), match.group(2)
|
||||
print(name, signature)
|
||||
types = re.findall(arg_type_pattern[ir], signature)
|
||||
print(types)
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = 2
|
||||
first_stage = list(stages.keys()).index(ir)
|
||||
|
||||
# cache manager
|
||||
so_path = make_stub(name, signature, constants)
|
||||
@@ -1384,58 +1443,42 @@ def compile(fn, **kwargs):
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
else:
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
# initialize compilation params
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
device = kwargs.get("device", torch.cuda.current_device())
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
|
||||
# load metadata if any
|
||||
metadata = None
|
||||
if fn_cache_manager.has_file(f'{name}.json'):
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||
# build compilation stages
|
||||
stages = {
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, compute_capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, compute_capability))
|
||||
}
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
metadata["shared"] = kwargs["shared"]
|
||||
|
||||
first_stage = list(stages.keys()).index(ext)
|
||||
asm = dict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir, (parse, compile) in list(stages.items())[first_stage:]:
|
||||
path = fn_cache_manager._make_path(f"{name}.{ir}")
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and \
|
||||
ir in metadata["ctime"] and \
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
if os.path.exists(path):
|
||||
metadata["ctime"][ir] = os.path.getctime(path)
|
||||
asm[ir] = next_module if ir == "cubin" else str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = ptx_get_kernel_name(next_module)
|
||||
module = next_module
|
||||
path = fn_cache_manager._make_path(f"{name}.{ir}")
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and\
|
||||
ir in metadata["ctime"] and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
if os.path.exists(path):
|
||||
metadata["ctime"][ir] = os.path.getctime(path)
|
||||
asm[ir] = next_module if ir == "cubin" else str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = ptx_get_kernel_name(next_module)
|
||||
module = next_module
|
||||
# write-back metadata
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
# return handle to compiled kernel
|
||||
@@ -1515,7 +1558,7 @@ class CudaUtils(object):
|
||||
}
|
||||
}
|
||||
|
||||
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
|
||||
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
|
||||
|
||||
static PyObject* loadBinary(PyObject* self, PyObject* args) {
|
||||
const char* name;
|
||||
@@ -1530,7 +1573,6 @@ class CudaUtils(object):
|
||||
CUmodule mod;
|
||||
int32_t n_regs = 0;
|
||||
int32_t n_spills = 0;
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
// create driver handles
|
||||
CUDA_CHECK(cuModuleLoadData(&mod, data));
|
||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
|
||||
@@ -1548,7 +1590,6 @@ class CudaUtils(object):
|
||||
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
|
||||
}
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
if(PyErr_Occurred()) {
|
||||
return NULL;
|
||||
|
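The rewritten `compile` then walks the stage dict (ast, ttir, ttgir, llir, ptx, cubin), re-parsing a cached artifact only when its on-disk ctime matches the value recorded in the metadata JSON and recompiling from the previous stage otherwise. A stripped-down sketch of that cache-or-compile loop, with placeholder `parse`/`compile_fn` callables and plain file writes standing in for the real cache manager:

import json
import os

def run_stages(stages, first_stage, module, cache_dir, name, metadata):
    """Minimal sketch of the cache-aware stage loop (not the real compiler)."""
    asm = {}
    for ir, (parse, compile_fn) in list(stages.items())[first_stage:]:
        path = os.path.join(cache_dir, f"{name}.{ir}")
        if os.path.exists(path) and metadata["ctime"].get(ir) == os.path.getctime(path):
            next_module = parse(path)          # cache hit: reuse the stored artifact
        else:
            next_module = compile_fn(module)   # cache miss: rebuild from the previous stage
            with open(path, "w") as f:
                f.write(str(next_module))
            metadata["ctime"][ir] = os.path.getctime(path)
        asm[ir] = next_module
        module = next_module
    with open(os.path.join(cache_dir, f"{name}.json"), "w") as f:
        json.dump(metadata, f)                 # write back the metadata for the next run
    return asm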
python/triton/impl/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
"""Triton internal implementation details.
|
||||
|
||||
Client libraries should not import interfaces from the `triton.impl` module,
|
||||
as the details are subject to change.
|
||||
|
||||
Public APIs defined in the `triton.impl` module will be re-exported
|
||||
in other relevant `triton` module namespaces.
|
||||
"""
|
||||
|
||||
from triton._C.libtriton.triton import ir
|
||||
from .base import (
|
||||
builtin,
|
||||
extern,
|
||||
is_builtin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"builtin",
|
||||
"extern",
|
||||
"ir",
|
||||
"is_builtin",
|
||||
]
|
python/triton/impl/base.py (new file, 36 lines)
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
TRITON_BUILTIN = "__triton_builtin__"
|
||||
|
||||
|
||||
def builtin(fn: T) -> T:
|
||||
"""Mark a function as a builtin."""
|
||||
assert callable(fn)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_builtin(fn) -> bool:
|
||||
"""Is this a registered triton builtin function?"""
|
||||
return getattr(fn, TRITON_BUILTIN, False)
|
||||
|
||||
|
||||
def extern(fn: T) -> T:
|
||||
"""A decorator for external functions."""
|
||||
return builtin(fn)
|
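Because `builtin` only tags the wrapper with a `__triton_builtin__` attribute, `is_builtin` reduces to an attribute lookup that the code generator can use instead of inspecting modules. A small usage sketch, assuming a build that ships this `triton.impl` module:

from triton.impl import builtin, is_builtin

@builtin
def my_op(x, _builder=None):
    # A toy builtin; the real ones call into `semantic`.
    return x

assert is_builtin(my_op)      # the wrapper carries __triton_builtin__
assert not is_builtin(max)    # ordinary callables are not marked

try:
    my_op(1)                  # no _builder outside @triton.jit -> error
except ValueError as e:
    print(e)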
@@ -1,4 +1,173 @@
|
||||
# flake8: noqa: F401
|
||||
"""isort:skip_file"""
|
||||
# Import order is significant here.
|
||||
|
||||
from ..impl import (
|
||||
ir,
|
||||
builtin,
|
||||
)
|
||||
from . import core, extern, libdevice, random
|
||||
from .core import *
|
||||
from .random import *
|
||||
from .core import (
|
||||
abs,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
atomic_cas,
|
||||
atomic_max,
|
||||
atomic_min,
|
||||
atomic_or,
|
||||
atomic_xchg,
|
||||
atomic_xor,
|
||||
bfloat16,
|
||||
block_type,
|
||||
builtin,
|
||||
cat,
|
||||
cdiv,
|
||||
constexpr,
|
||||
cos,
|
||||
debug_barrier,
|
||||
dot,
|
||||
dtype,
|
||||
exp,
|
||||
fdiv,
|
||||
float16,
|
||||
float32,
|
||||
float64,
|
||||
float8,
|
||||
function_type,
|
||||
int1,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
int8,
|
||||
load,
|
||||
log,
|
||||
max,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
multiple_of,
|
||||
num_programs,
|
||||
pi32_t,
|
||||
pointer_type,
|
||||
printf,
|
||||
program_id,
|
||||
ravel,
|
||||
sigmoid,
|
||||
sin,
|
||||
softmax,
|
||||
sqrt,
|
||||
store,
|
||||
sum,
|
||||
swizzle2d,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
uint8,
|
||||
umulhi,
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
philox,
|
||||
philox_impl,
|
||||
rand,
|
||||
rand4x,
|
||||
randint,
|
||||
randint4x,
|
||||
randn,
|
||||
randn4x,
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"abs",
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"atomic_add",
|
||||
"atomic_and",
|
||||
"atomic_cas",
|
||||
"atomic_max",
|
||||
"atomic_min",
|
||||
"atomic_or",
|
||||
"atomic_xchg",
|
||||
"atomic_xor",
|
||||
"bfloat16",
|
||||
"block_type",
|
||||
"builtin",
|
||||
"cat",
|
||||
"cdiv",
|
||||
"constexpr",
|
||||
"cos",
|
||||
"debug_barrier",
|
||||
"dot",
|
||||
"dtype",
|
||||
"exp",
|
||||
"fdiv",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"float8",
|
||||
"function_type",
|
||||
"int1",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"load",
|
||||
"log",
|
||||
"max",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
"min",
|
||||
"minimum",
|
||||
"multiple_of",
|
||||
"num_programs",
|
||||
"pair_uniform_to_normal",
|
||||
"philox",
|
||||
"philox_impl",
|
||||
"pi32_t",
|
||||
"pointer_type",
|
||||
"printf",
|
||||
"program_id",
|
||||
"rand",
|
||||
"rand4x",
|
||||
"randint",
|
||||
"randint4x",
|
||||
"randn",
|
||||
"randn4x",
|
||||
"ravel",
|
||||
"sigmoid",
|
||||
"sin",
|
||||
"softmax",
|
||||
"sqrt",
|
||||
"store",
|
||||
"sum",
|
||||
"swizzle2d",
|
||||
"tensor",
|
||||
"trans",
|
||||
"triton",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32_to_uniform_float",
|
||||
"uint64",
|
||||
"uint8",
|
||||
"umulhi",
|
||||
"void",
|
||||
"where",
|
||||
"xor_sum",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
]
|
||||
|
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from typing import List, Callable, TypeVar
|
||||
|
||||
import triton
|
||||
from . import semantic
|
||||
from . import builtin, semantic
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def _to_tensor(x, builder):
|
||||
if isinstance(x, bool):
|
||||
@@ -33,17 +33,6 @@ def _to_tensor(x, builder):
|
||||
assert False, f'cannot convert {x} to tensor'
|
||||
|
||||
|
||||
def builtin(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class dtype:
|
||||
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -405,14 +394,14 @@ class constexpr:
|
||||
return constexpr(self.value != other.value)
|
||||
|
||||
def __bool__(self):
|
||||
return constexpr(bool(self.value))
|
||||
return bool(self.value)
|
||||
|
||||
def __neg__(self):
|
||||
return constexpr(-self.value)
|
||||
|
||||
|
||||
def __pos__(self):
|
||||
return constexpr(+self.value)
|
||||
|
||||
|
||||
def __invert__(self):
|
||||
return constexpr(~self.value)
|
||||
|
||||
@@ -614,9 +603,9 @@ class tensor:
|
||||
assert False, "unsupported"
|
||||
return ret
|
||||
|
||||
# x[:, None, :, None]
|
||||
# x = expand_dims(x, axis=1)
|
||||
# x = expand_dims(x, axis=2)
|
||||
@property
|
||||
def T(self):
|
||||
assert False, "Transposition must be created by the AST Visitor"
|
||||
|
||||
@builtin
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
@@ -737,6 +726,9 @@ def broadcast_to(input, shape, _builder=None):
|
||||
"""
|
||||
return semantic.broadcast_impl_shape(input, shape, _builder)
|
||||
|
||||
@builtin
|
||||
def trans(input, _builder=None):
|
||||
return semantic.trans(input, _builder)
|
||||
|
||||
@builtin
|
||||
def cat(input, other, _builder=None):
|
||||
@@ -766,6 +758,10 @@ def view(input, shape, _builder=None):
|
||||
shape = [x.value for x in shape]
|
||||
return semantic.view(input, shape, _builder)
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
# TODO: should be more than just a view
|
||||
return view(input, shape, _builder)
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
@@ -773,7 +769,7 @@ def view(input, shape, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
|
||||
def dot(input, other, allow_tf32=True, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
@@ -785,7 +781,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
|
||||
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||
"""
|
||||
allow_tf32 = _constexpr_to_value(allow_tf32)
|
||||
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
|
||||
return semantic.dot(input, other, allow_tf32, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -847,9 +843,9 @@ def store(pointer, value, mask=None, _builder=None):
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
|
||||
def _add_atomic_docstr(name):
|
||||
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||
|
||||
@@ -970,9 +966,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Computes the element-wise {name} of :code:`x`
|
||||
|
||||
@@ -1019,9 +1015,9 @@ def sqrt(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
def _add_reduction_docstr(name):
|
||||
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
|
||||
|
||||
@@ -1041,6 +1037,13 @@ def max(input, axis, _builder=None):
|
||||
return semantic.max(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("maximum index")
|
||||
def argmax(input, axis, _builder=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.argmax(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis, _builder=None):
|
||||
@@ -1048,6 +1051,13 @@ def min(input, axis, _builder=None):
|
||||
return semantic.min(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("minimum index")
|
||||
def argmin(input, axis, _builder=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.argmin(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis, _builder=None):
|
||||
|
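The docstring helpers in `triton.language.core` are now typed as decorator factories returning `Callable[[T], T]`, so `@_add_reduction_docstr("maximum index")` no longer erases the decorated function's type for static checkers. The shape of that pattern, reduced to a self-contained sketch with an abbreviated docstring:

from typing import Callable, TypeVar

T = TypeVar("T")

def _add_reduction_docstr(name: str) -> Callable[[T], T]:
    def _decorator(func: T) -> T:
        # Attach a formatted docstring and hand back the same (typed) function.
        func.__doc__ = f"Returns the {name} of all elements in the input tensor along the provided axis."
        return func
    return _decorator

@_add_reduction_docstr("maximum index")
def argmax(input, axis):
    ...

print(argmax.__doc__)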
@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_external_elementwise")
|
||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||
|
||||
|
||||
class ExternalFunction:
|
||||
'''
|
||||
A wrapper for external functions
|
||||
'''
|
||||
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def extern(fn):
|
||||
'''
|
||||
A decorator for external functions
|
||||
'''
|
||||
return ExternalFunction(fn)
|
||||
|
File diff suppressed because it is too large
@@ -502,6 +502,11 @@ def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
# TODO: check types
|
||||
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
|
||||
|
||||
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
if len(input.shape) != 2:
|
||||
raise ValueError("Only 2D tensors can be transposed")
|
||||
ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
|
||||
return tl.tensor(builder.create_trans(input.handle), ret_type)
|
||||
|
||||
def broadcast_impl_shape(input: tl.tensor,
|
||||
shape: List[int],
|
||||
@@ -971,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
if lhs.type.scalar.is_int():
|
||||
@@ -981,11 +984,11 @@ def dot(lhs: tl.tensor,
|
||||
else:
|
||||
_0 = builder.get_float32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
M = lhs.type.shape[1 if trans_a else 0]
|
||||
N = rhs.type.shape[0 if trans_b else 1]
|
||||
M = lhs.type.shape[0]
|
||||
N = rhs.type.shape[1]
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
|
||||
|
||||
@@ -1061,10 +1064,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
|
||||
|
||||
|
||||
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
|
||||
|
||||
|
||||
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
|
||||
|
||||
|
||||
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
|
||||
|
||||
|
||||
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
|
||||
|
||||
@@ -1109,16 +1120,16 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
##
|
||||
|
||||
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
x.handle.multiple_of(values)
|
||||
return x
|
||||
|
||||
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
|
||||
return x
|
||||
|
||||
|
||||
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
||||
x.handle.max_contiguous(values)
|
||||
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
|
||||
return x
|
||||
|
||||
|
||||
|
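`argmin`/`argmax` are routed through the same `reduce_impl` path as the other reductions, picking ARGFMIN/ARGFMAX for floats and ARGMIN/ARGMAX for integers. At the kernel level they read like any other reduction; a hedged one-block sketch (shapes and launch grid are illustrative only):

import torch
import triton
import triton.language as tl

@triton.jit
def argmax_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Reduce the whole block to the index of its maximum element (int32 result).
    tl.store(out_ptr, tl.argmax(x, axis=0))

x = torch.randn(1024, device='cuda')
out = torch.empty((1,), device='cuda', dtype=torch.int32)
argmax_kernel[(1,)](x, out, BLOCK=1024)
assert out.item() == torch.argmax(x).item()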
@@ -1,5 +1,12 @@
|
||||
# flake8: noqa: F401
|
||||
#from .conv import _conv, conv
|
||||
# from .conv import _conv, conv
|
||||
from . import blocksparse
|
||||
from .cross_entropy import _cross_entropy, cross_entropy
|
||||
from .matmul import _matmul, matmul
|
||||
|
||||
__all__ = [
|
||||
"blocksparse",
|
||||
"_cross_entropy",
|
||||
"cross_entropy",
|
||||
"_matmul",
|
||||
"matmul",
|
||||
]
|
||||
|
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa: F401
|
||||
from .matmul import matmul
|
||||
from .softmax import softmax
|
||||
|
||||
__all__ = [
|
||||
"matmul",
|
||||
"softmax",
|
||||
]
|
||||
|
@@ -1,2 +1,12 @@
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics
|
||||
from .jit import JITFunction, KernelInterface, version_key
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"Heuristics",
|
||||
"autotune",
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
]
|
||||
|
@@ -8,6 +8,7 @@ import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,6 +20,9 @@ try:
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -94,20 +98,20 @@ def version_key():
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface:
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
def __getitem__(self, grid):
|
||||
def __getitem__(self, grid) -> T:
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
def launcher(*args, **kwargs):
|
||||
return self.run(*args, grid=grid, **kwargs)
|
||||
return launcher
|
||||
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
|
||||
|
||||
|
||||
class JITFunction(KernelInterface):
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
@overload
|
||||
def jit(fn: T) -> JITFunction[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def jit(
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Callable[[T], JITFunction[T]]:
|
||||
...
|
||||
|
||||
|
||||
def jit(
|
||||
fn: Optional[T] = None,
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
|
||||
implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* builtins within the triton package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
return JITFunction(
|
||||
fn,
|
||||
version=version,
|
||||
do_not_specialize=do_not_specialize,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
|
||||
|
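With the overloads, `@triton.jit` keeps working both bare and with keyword arguments such as `do_not_specialize`, and `KernelInterface[T]` makes `fn[grid]` typed as the wrapped function. The calling pattern itself is unchanged; a hedged sketch of a trivial kernel and its launch:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)  # fn[grid](...) binds the grid via __getitem__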
@@ -34,12 +34,12 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().float()
|
||||
ref_ret = ref_ret.half().to(dtype)
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
|
@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
|
||||
grad_to_none=[x], rep=500)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
# test_layer_norm(1151, 8192, torch.float16)
|
||||
bench_layer_norm.run(save_path='.', print_data=True)
|
||||
|
@@ -50,7 +50,7 @@ def _fwd_kernel(
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
@@ -165,26 +165,26 @@ def _bwd_kernel(
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
dp += tl.dot(do, tl.trans(v))
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
|
||||
# compute dq
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq)
|
||||
# # increment pointers
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
@@ -195,6 +195,7 @@ def _bwd_kernel(
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -205,7 +206,7 @@ class _attention(torch.autograd.Function):
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
@@ -224,6 +225,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
@@ -268,13 +270,13 @@ class _attention(torch.autograd.Function):
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
|
||||
sm_scale = 0.2
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
@@ -283,13 +285,16 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
# print(ref_out)
|
||||
# print(tri_out)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -299,3 +304,50 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'],
|
||||
line_names=['Triton'],
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['fwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
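The flash-attention changes above show the migration away from the removed `trans_a`/`trans_b` flags of `tl.dot`: transposition is now written explicitly with `tl.trans` (or the `.T` attribute handled by the AST visitor), and the dot is always a plain [M, K] x [K, N] product. A hedged kernel fragment illustrating the new style; block sizes and strides are illustrative only:

import triton
import triton.language as tl

@triton.jit
def qk_kernel(q_ptr, k_ptr, out_ptr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    q = tl.load(q_ptr + offs_m[:, None] * BLOCK_D + offs_d[None, :])  # [M, D]
    k = tl.load(k_ptr + offs_n[:, None] * BLOCK_D + offs_d[None, :])  # [N, D]
    # Previously: qk = tl.dot(q, k, trans_b=True)
    qk = tl.dot(q, tl.trans(k))                                       # [M, N]
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], qk)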
@@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
|
||||
// CHECK-NEXT: Membar 6
|
||||
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 9
|
||||
@@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
return
|
||||
}
|
||||
|
||||
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
|
||||
// So we need a barrier both before and after cst1
|
||||
// CHECK-LABEL: for_reuse
|
||||
func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 7
|
||||
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 10
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: for_reuse_nested
|
||||
func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 7
|
||||
%cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 11
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -387,6 +387,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
|
||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
|
||||
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
|
||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_fallback
|
||||
func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
|
||||
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
|
||||
%cst_scalar = arith.constant 64 : i32
|
||||
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
|
||||
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
|
||||
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
|
||||
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
|
||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
@@ -428,6 +467,100 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_mask
func @basic_insert_slice_async_mask(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
%true = arith.constant 1 : i1
%true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>

// CHECK: llvm.select
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.select
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
return
}
}

// -----

#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_mask_other
func @basic_insert_slice_async_mask_other(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
%true = arith.constant 1 : i1
%true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>
%other = arith.constant 1.0 : f32
%other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL>

// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 0 ]
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 16 ]
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor, %other_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
return
}
}

// -----

#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -712,8 +845,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK-LABEL: convert_layout_mmav2_block
func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -860,6 +1017,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// -----

#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (f32, f32, f32, f32)
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (f32, f32, f32, f32)
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>

// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>

%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<32x32xf32, #blocked>
return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
@@ -897,9 +1093,9 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
module attributes {"triton_gpu.num-warps" = 4 : i32} {

func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.ntid.x
// CHECK: nvvm.read.ptx.sreg.ntid.y
// CHECK: nvvm.read.ptx.sreg.ntid.z
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32

@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32

%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>

// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>

// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>

return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>

@@ -26,7 +26,9 @@ struct TestMembarPass
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
Allocation allocation(operation);
MembarAnalysis analysis(&allocation);
MembarAnalysis membarPass(&allocation);
membarPass.run();

size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {
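Note on the hunk above: MembarAnalysis is no longer executed from its constructor; the caller now constructs it against an Allocation and invokes run() explicitly. A minimal C++ sketch of that call sequence inside an MLIR test pass follows. Only the Allocation/MembarAnalysis/run() usage comes from this diff; the pass name, include paths, and PassWrapper scaffolding are assumptions for illustration.

// Sketch only: shows the construct-then-run() usage of MembarAnalysis from
// the hunk above. Include paths and the pass boilerplate are assumed, not
// taken from this change.
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h" // assumed header location
#include "triton/Analysis/Membar.h"     // assumed header location

using namespace mlir;

struct ExampleMembarTestPass
    : public PassWrapper<ExampleMembarTestPass, OperationPass<>> {
  void runOnOperation() override {
    Operation *operation = getOperation();
    // Shared-memory buffers and offsets are assigned first; the membar
    // analysis consults this allocation to detect shared-memory hazards.
    Allocation allocation(operation);
    // Construct the analysis, then drive it explicitly: run() walks the IR
    // and inserts gpu.barrier ops where synchronization is required.
    MembarAnalysis membarPass(&allocation);
    membarPass.run();
  }
};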