[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)
@@ -295,6 +295,18 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
     let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
 }
 
+def TT_TransOp : TT_Op<"trans", [NoSideEffect,
+                                 SameOperandsAndResultElementType]> {
+
+    let summary = "transpose a tensor";
+
+    let arguments = (ins TT_Tensor:$src);
+
+    let results = (outs TT_Tensor:$result);
+
+    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+}
+
 //
 // SPMD Ops
 //
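Given the declared assembly format, the new op should round-trip in textual IR along the lines of `%1 = tt.trans %0 : (tensor<32x64xf16>) -> tensor<64x32xf16>` (an illustrative shape, not taken from the commit).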
@@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
   // These ops may allocate a new shared memory buffer.
   auto result = op->getResult(0);
   // FIXME(Keren): extract and insert are always alias for now
-  if (isa<tensor::ExtractSliceOp>(op)) {
+  if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
     // extract_slice %src
     aliasInfo = AliasInfo(operands[0]->getValue());
     pessimistic = false;
@@ -105,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) {
 }
 
 bool maybeAliasOp(Operation *op) {
-  return isa<tensor::ExtractSliceOp>(op) ||
+  return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
          isa<triton::gpu::InsertSliceAsyncOp>(op) ||
          isa<tensor::InsertSliceOp>(op);
 }
@@ -2716,6 +2716,9 @@ public:
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
     auto inOrd = srcBlockedLayout.getOrder();
     auto outOrd = dstSharedLayout.getOrder();
+    if (inOrd != outOrd)
+      llvm_unreachable(
+          "blocked -> shared with different order not yet implemented");
     unsigned inVec =
         inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
     unsigned outVec = dstSharedLayout.getVec();
@@ -2775,7 +2778,8 @@ public:
         getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
     for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
          ++linearWordIdx) {
-      // step 1: recover the multidim_index from the index of input_elements
+      // step 1: recover the multidim_index from the index of
+      // input_elements
       auto multiDimWordIdx =
           getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
       SmallVector<Value> multiDimIdx(2);
@@ -3711,6 +3715,33 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
 
 /// ====================== mma codegen end ============================
 
+/// ====================== trans codegen begin ============================
+
+struct TransOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto srcSmemObj =
+        getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
+    SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
+                                     srcSmemObj.strides[0]};
+    SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
+                                     srcSmemObj.offsets[0]};
+    auto dstSmemObj =
+        SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
+    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
+    rewriter.replaceOp(op, retVal);
+    return success();
+  }
+};
+
+/// ====================== trans codegen end ============================
+
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
                                       Type resType, Type elemType,
                                       Value constVal,
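The lowering moves no data: a shared-memory tensor is transposed by reusing the same base pointer with its two strides and offsets swapped, which is also why the alias analysis above treats tt.trans like extract_slice. A minimal NumPy sketch of the same zero-copy trick (illustrative only, not Triton code):

import numpy as np

# A 2-D array and a transposed *view* of it: same buffer, strides swapped.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
xt = np.lib.stride_tricks.as_strided(
    x, shape=(x.shape[1], x.shape[0]), strides=(x.strides[1], x.strides[0]))

assert np.shares_memory(x, xt)   # no copy was made
assert (xt == x.T).all()         # and it reads as the transpose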
@@ -4538,6 +4569,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
                                                            benefit);
   patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
+  patterns.add<TransOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
 }
 
@@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   }
 };
 
+struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
+
+  using OpConversionPattern<triton::TransOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value src = adaptor.src();
+    auto srcType = src.getType().cast<RankedTensorType>();
+    Attribute srcEncoding = srcType.getEncoding();
+    if (!srcEncoding)
+      return failure();
+    if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      // TODO: end-to-end correctness is broken if
+      // the input is blocked and the output is shared
+      // with different order. Maybe a backend issue in BlockedToShared?
+      SmallVector<unsigned> order = {1, 0};
+      if (auto srcBlockedEncoding =
+              srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
+        llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
+      srcEncoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      srcType = RankedTensorType::get(srcType.getShape(),
+                                      srcType.getElementType(), srcEncoding);
+      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
+                                                          srcType, src);
+    }
+    auto srcSharedEncoding =
+        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
+    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
+                                   srcSharedEncoding.getOrder().end());
+    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
+                                   srcType.getShape().end());
+    std::reverse(retOrder.begin(), retOrder.end());
+    std::reverse(retShapes.begin(), retShapes.end());
+    auto retEncoding =
+        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
+    auto retType =
+        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
+
+    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
+    return success();
+  }
+};
+
 struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
 
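The net effect of TritonTransPattern on types is simple: the element type is preserved while both the shape and the shared-memory order are reversed. As a plain-Python sketch (hypothetical helper, not Triton API):

def trans_result_type(shape, order):
    # reverse both the shape and the layout order; element type is unchanged
    return list(reversed(shape)), list(reversed(order))

assert trans_result_type([64, 32], [1, 0]) == ([32, 64], [0, 1])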
@@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                TritonGenericPattern<triton::PtrToIntOp>,
                TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
                TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
-               TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
-               TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
-               TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
+               TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
+               TritonDotPattern, TritonLoadPattern, TritonStorePattern,
+               TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
+      typeConverter, context);
 }
 
 //
@@ -178,6 +178,10 @@ public:
         !isSharedEncoding(convert.getResult())) {
       return mlir::failure();
     }
+    if (isSharedEncoding(convert.getOperand()) &&
+        isSharedEncoding(convert.getResult())) {
+      return mlir::failure();
+    }
     auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
     auto srcShared =
         srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
@@ -661,6 +665,54 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
 
 } // namespace
 
+class OptimizeBlockedToShared : public mlir::RewritePattern {
+public:
+  OptimizeBlockedToShared(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto srcBlockedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    auto dstSharedLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
+    if (!srcBlockedLayout || !dstSharedLayout)
+      return failure();
+    if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
+      return failure();
+    // For now only works if single use is transpose
+    // TODO: rematerialize #shared uses
+    auto users = op->getUsers();
+    if (std::distance(users.begin(), users.end()) != 1 ||
+        !isa<triton::TransOp>(*users.begin()))
+      return failure();
+
+    auto tmpShared = triton::gpu::SharedEncodingAttr::get(
+        op->getContext(), dstSharedLayout.getVec(),
+        dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
+        srcBlockedLayout.getOrder());
+    auto tmpType = RankedTensorType::get(srcType.getShape(),
+                                         srcType.getElementType(), tmpShared);
+    auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), tmpType, cvt.getOperand());
+
+    auto newDstType = RankedTensorType::get(
+        users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
+        srcType.getElementType(), dstSharedLayout);
+
+    auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
+                                                     tmpCvt.getResult());
+
+    rewriter.replaceOp(*users.begin(), newTrans.getResult());
+    return success();
+  }
+};
+
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;
 
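Schematically, when a blocked-to-shared conversion changes the order and its only user is a transpose, the pattern rewrites (pseudo-IR, layouts abbreviated):

    before: %s = convert_layout %b : blocked(order=[0,1]) -> shared(order=[1,0])
            %t = tt.trans %s
    after:  %tmp = convert_layout %b : blocked(order=[0,1]) -> shared(order=[0,1])
            %t   = tt.trans %tmp : shared(order=[0,1]) -> shared(order=[1,0])

The order change now rides on the transpose, which merely swaps strides, so the remaining conversion is an order-preserving (cheap) copy.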
@@ -755,6 +807,7 @@ public:
 
   mlir::RewritePatternSet patterns(context);
 
+  patterns.add<OptimizeBlockedToShared>(context);
   patterns.add<SimplifyConversion>(context);
   patterns.add<DecomposeDotOperand>(context);
   patterns.add<RematerializeBackward>(context);
 
@@ -1085,6 +1085,16 @@ void init_triton_ir(py::module &&m) {
                  mlir::RankedTensorType::get(shape, lhsType.getElementType()),
                  lhs, rhs);
            })
+      .def("create_trans",
+           [](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
+             auto argEltType = argType.getElementType();
+             std::vector<int64_t> retShape = argType.getShape();
+             std::reverse(retShape.begin(), retShape.end());
+             return self.create<mlir::triton::TransOp>(
+                 loc, mlir::RankedTensorType::get(retShape, argEltType), arg);
+           })
       .def("create_broadcast",
            [](mlir::OpBuilder &self, mlir::Value &arg,
               std::vector<int64_t> &shape) -> mlir::Value {
@@ -667,7 +667,6 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
         tl.atomic_add(Z + off1, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    print(x)
     # reference result
     z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
@@ -1067,122 +1066,126 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ---------------
 
 
-# @pytest.mark.parametrize("epilogue, allow_tf32, dtype",
-#                          [(epilogue, allow_tf32, dtype)
-#                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
-#                           for allow_tf32 in [True, False]
-#                           for dtype in ['float16']
-#                           if not (allow_tf32 and (dtype in ['float16']))])
-# def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
-#     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-#     if cc < 80:
-#         if dtype == 'int8':
-#             pytest.skip("Only test int8 on devices with sm >= 80")
-#         elif dtype == 'float32' and allow_tf32:
-#             pytest.skip("Only test tf32 on devices with sm >= 80")
+@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
+                         [(epilogue, allow_tf32, dtype)
+                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
+                          for allow_tf32 in [True, False]
+                          for dtype in ['float16']
+                          if not (allow_tf32 and (dtype in ['float16']))])
+def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 80:
+        if dtype == 'int8':
+            pytest.skip("Only test int8 on devices with sm >= 80")
+        elif dtype == 'float32' and allow_tf32:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
 
-#     M, N, K = 128, 128, 64
-#     num_warps = 8
-#     trans_a, trans_b = False, False
+    M, N, K = 64, 64, 64
+    num_warps = 4
+    trans_a, trans_b = False, False
 
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, stride_xm, stride_xk,
-#                Y, stride_yk, stride_yn,
-#                W, stride_wn, stride_wl,
-#                Z, stride_zm, stride_zn,
-#                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-#                ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
-#                ALLOW_TF32: tl.constexpr,
-#                DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
-#                TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
-#         off_m = tl.arange(0, BLOCK_M)
-#         off_n = tl.arange(0, BLOCK_N)
-#         off_l = tl.arange(0, BLOCK_N)
-#         off_k = tl.arange(0, BLOCK_K)
-#         Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
-#         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
-#         Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
-#         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-#         z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
-#         if ADD_MATRIX:
-#             z += tl.load(Zs)
-#         if ADD_ROWS:
-#             ZRs = Z + off_m * stride_zm
-#             z += tl.load(ZRs)[:, None]
-#         if ADD_COLS:
-#             ZCs = Z + off_n * stride_zn
-#             z += tl.load(ZCs)[None, :]
-#         if DO_SOFTMAX:
-#             max = tl.max(z, 1)
-#             z = z - max[:, None]
-#             num = tl.exp(z)
-#             den = tl.sum(num, 1)
-#             z = num / den[:, None]
-#         if CHAIN_DOT:
-#             # tl.store(Zs, z)
-#             # tl.debug_barrier()
-#             z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
-#         tl.store(Zs, z)
-#     # input
-#     rs = RandomState(17)
-#     x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
-#     y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
-#     w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
-#     if allow_tf32:
-#         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
-#         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
-#         w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
-#     x_tri = to_triton(x, device=device)
-#     y_tri = to_triton(y, device=device)
-#     w_tri = to_triton(w, device=device)
-#     # triton result
-#     z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
-#     z_tri = to_triton(z, device=device)
-#     if epilogue == 'trans':
-#         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
-#     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
-#                          y_tri, y_tri.stride(0), y_tri.stride(1),
-#                          w_tri, w_tri.stride(0), w_tri.stride(1),
-#                          z_tri, z_tri.stride(0), z_tri.stride(1),
-#                          TRANS_A=trans_a, TRANS_B=trans_b,
-#                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
-#                          ADD_MATRIX=epilogue == 'add-matrix',
-#                          ADD_ROWS=epilogue == 'add-rows',
-#                          ADD_COLS=epilogue == 'add-cols',
-#                          DO_SOFTMAX=epilogue == 'softmax',
-#                          CHAIN_DOT=epilogue == 'chain-dot',
-#                          ALLOW_TF32=allow_tf32,
-#                          num_warps=num_warps)
-#     # torch result
-#     x_ref = x.T if trans_a else x
-#     y_ref = y.T if trans_b else y
-#     z_ref = np.matmul(x_ref, y_ref)
-#     if epilogue == 'add-matrix':
-#         z_ref += z
-#     if epilogue == 'add-rows':
-#         z_ref += z[:, 0][:, None]
-#     if epilogue == 'add-cols':
-#         z_ref += z[0, :][None, :]
-#     if epilogue == 'softmax':
-#         num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
-#         denom = np.sum(num, axis=-1, keepdims=True)
-#         z_ref = num / denom
-#     if epilogue == 'chain-dot':
-#         z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
-#     # compare
-#     # print(z_ref[:,0], z_tri[:,0])
-#     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
-#     # make sure ld/st are vectorized
-#     ptx = pgm.asm['ptx']
-#     assert 'ld.global.v4' in ptx
-#     assert 'st.global.v4' in ptx
-#     if allow_tf32:
-#         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-#     elif dtype == 'float32':
-#         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
-#     elif dtype == 'int8':
-#         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xk,
+               Y, stride_yk, stride_yn,
+               W, stride_wn, stride_wl,
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
+               ALLOW_TF32: tl.constexpr,
+               DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
+               TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        off_l = tl.arange(0, BLOCK_N)
+        off_k = tl.arange(0, BLOCK_K)
+        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
+        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
+        Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        x = tl.load(Xs)
+        y = tl.load(Ys)
+        x = tl.trans(x) if TRANS_A else x
+        y = tl.trans(y) if TRANS_B else y
+        z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
+        if ADD_MATRIX:
+            z += tl.load(Zs)
+        if ADD_ROWS:
+            ZRs = Z + off_m * stride_zm
+            z += tl.load(ZRs)[:, None]
+        if ADD_COLS:
+            ZCs = Z + off_n * stride_zn
+            z += tl.load(ZCs)[None, :]
+        if DO_SOFTMAX:
+            max = tl.max(z, 1)
+            z = z - max[:, None]
+            num = tl.exp(z)
+            den = tl.sum(num, 1)
+            z = num / den[:, None]
+        if CHAIN_DOT:
+            # tl.store(Zs, z)
+            # tl.debug_barrier()
+            z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
+        tl.store(Zs, z)
+    # input
+    rs = RandomState(17)
+    x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
+    y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
+    w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
+    if allow_tf32:
+        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
+    x_tri = to_triton(x, device=device)
+    y_tri = to_triton(y, device=device)
+    w_tri = to_triton(w, device=device)
+    # triton result
+    z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
+    z_tri = to_triton(z, device=device)
+    if epilogue == 'trans':
+        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         y_tri, y_tri.stride(0), y_tri.stride(1),
+                         w_tri, w_tri.stride(0), w_tri.stride(1),
+                         z_tri, z_tri.stride(0), z_tri.stride(1),
+                         TRANS_A=trans_a, TRANS_B=trans_b,
+                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
+                         ADD_MATRIX=epilogue == 'add-matrix',
+                         ADD_ROWS=epilogue == 'add-rows',
+                         ADD_COLS=epilogue == 'add-cols',
+                         DO_SOFTMAX=epilogue == 'softmax',
+                         CHAIN_DOT=epilogue == 'chain-dot',
+                         ALLOW_TF32=allow_tf32,
+                         num_warps=num_warps)
+    # torch result
+    x_ref = x.T if trans_a else x
+    y_ref = y.T if trans_b else y
+    z_ref = np.matmul(x_ref, y_ref)
+    if epilogue == 'add-matrix':
+        z_ref += z
+    if epilogue == 'add-rows':
+        z_ref += z[:, 0][:, None]
+    if epilogue == 'add-cols':
+        z_ref += z[0, :][None, :]
+    if epilogue == 'softmax':
+        num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
+        denom = np.sum(num, axis=-1, keepdims=True)
+        z_ref = num / denom
+    if epilogue == 'chain-dot':
+        z_ref = np.matmul(z_ref.T, w)
+    # compare
+    # print(z_ref[:,0], z_tri[:,0])
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    # make sure ld/st are vectorized
+    ptx = pgm.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
+    if allow_tf32:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+    elif dtype == 'float32':
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+    elif dtype == 'int8':
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
 
 # def test_dot_without_load():
@@ -1367,7 +1367,7 @@ prototype_pattern = {
     "ptx": ptx_prototype_pattern,
 }
 
-mlir_arg_type_pattern = r'%\w+: ([^,\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
+mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
 ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
 arg_type_pattern = {
     "ttir": mlir_arg_type_pattern,
@@ -1424,7 +1424,9 @@ def compile(fn, **kwargs):
         import re
         match = re.search(prototype_pattern[ir], src, re.MULTILINE)
         name, signature = match.group(1), match.group(2)
+        print(name, signature)
         types = re.findall(arg_type_pattern[ir], signature)
+        print(types)
         param_tys = [convert_type_repr(ty) for ty in types]
         signature = {k: v for k, v in enumerate(param_tys)}
         first_stage = list(stages.keys()).index(ir)
@@ -614,6 +614,7 @@ class tensor:
             assert False, "unsupported"
         return ret
 
+
 # x[:, None, :, None]
 # x = expand_dims(x, axis=1)
 # x = expand_dims(x, axis=2)
@@ -737,6 +738,9 @@ def broadcast_to(input, shape, _builder=None):
     """
     return semantic.broadcast_impl_shape(input, shape, _builder)
 
+@builtin
+def trans(input, _builder=None):
+    return semantic.trans(input, _builder)
 
 @builtin
 def cat(input, other, _builder=None):
@@ -502,6 +502,11 @@ def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
     # TODO: check types
     return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
 
+def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if len(input.shape) != 2:
+        raise ValueError("Only 2D tensors can be transposed")
+    ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
+    return tl.tensor(builder.create_trans(input.handle), ret_type)
 
 def broadcast_impl_shape(input: tl.tensor,
                          shape: List[int],
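Together, the builtin (tl.trans), the semantic helper, and the create_trans binding make transposition available inside kernels. A minimal usage sketch (kernel invented for illustration; assumes a CUDA device and this commit's Triton build):

import torch
import triton
import triton.language as tl

@triton.jit
def transpose_kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # load a (BLOCK_M, BLOCK_N) tile, transpose it, store the (BLOCK_N, BLOCK_M) result
    x = tl.load(X + offs_m[:, None] * BLOCK_N + offs_n[None, :])
    tl.store(Z + offs_n[:, None] * BLOCK_M + offs_m[None, :], tl.trans(x))

x = torch.randn((32, 64), device='cuda', dtype=torch.float16)
z = torch.empty((64, 32), device='cuda', dtype=torch.float16)
transpose_kernel[(1,)](x, z, BLOCK_M=32, BLOCK_N=64)
assert torch.equal(z, x.t())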
@@ -32,7 +32,7 @@ def _fwd_kernel(
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
-    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
+    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
@@ -50,7 +50,7 @@ def _fwd_kernel(
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk += tl.dot(q, tl.trans(k))
        qk *= sm_scale
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
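With the trans_b flag gone from tl.dot, the K tile is now loaded in (BLOCK_N, BLOCK_DMODEL) layout (note the swapped offsets in off_k) and transposed explicitly, so numerically qk remains `q @ k.T`.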
@@ -165,26 +165,26 @@ def _bwd_kernel(
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
-            qk = tl.dot(q, k, trans_b=True)
+            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
-            dv += tl.dot(p.to(tl.float16), do, trans_a=True)
+            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp += tl.dot(do, v, trans_b=True)
+            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
-            dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
-            # # compute dq
+            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
+            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(tl.float16), k)
            tl.store(dq_ptrs, dq)
-            # # increment pointers
+            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
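For reference, the tile-level algebra the rewritten backward kernel implements, as a NumPy sketch (shapes illustrative; the Di and masking bookkeeping of the real kernel is simplified away):

import numpy as np

M, N, D = 16, 16, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((M, D))
k = rng.standard_normal((N, D))
v = rng.standard_normal((N, D))
do = rng.standard_normal((M, D))
sm_scale = 0.2

qk = q @ k.T                       # qk = tl.dot(q, tl.trans(k))
s = qk * sm_scale
p = np.exp(s - s.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)  # recompute p = softmax(qk * sm_scale)
dv = p.T @ do                      # dv += tl.dot(tl.trans(p), do)
dp = do @ v.T                      # dp += tl.dot(do, tl.trans(v))
ds = p * dp * sm_scale             # ds = p * dp * sm_scale
dk = ds.T @ q                      # dk += tl.dot(tl.trans(ds), q)
dq = ds @ k                        # dq += tl.dot(ds, k)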
@@ -273,7 +273,7 @@ attention = _attention.apply
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
 def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
     k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
     v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
     sm_scale = 0.2
@@ -287,23 +287,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    # ref_out.backward(dout)
-    # ref_dv, v.grad = v.grad.clone(), None
-    # ref_dk, k.grad = k.grad.clone(), None
-    # ref_dq, q.grad = q.grad.clone(), None
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = attention(q, k, v, sm_scale)
     # print(ref_out)
     # print(tri_out)
-    # tri_out.backward(dout)
-    # tri_dv, v.grad = v.grad.clone(), None
-    # tri_dk, k.grad = k.grad.clone(), None
-    # tri_dq, q.grad = q.grad.clone(), None
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
-    # triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dv, tri_dv)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)
 
 
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
@@ -350,4 +350,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
 
-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)