diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index ed569c04d..9e0c72de9 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -10,77 +10,108 @@ import triton @pytest.mark.parametrize("BLOCK", [16, 32, 64]) @pytest.mark.parametrize("DTYPE", [torch.float16]) def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): - # set seed - torch.random.manual_seed(0) + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK) # create inputs - a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda") - b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda") + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) shape = { "sdd": (M, N), - "dsd": (a.shape[2], a.shape[3]), - "dds": (b.shape[2], b.shape[3]), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1) + b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1) + dc_ref, dc_tri = triton.testing.make_pair(c_shape) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.retain_grad() + b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, + b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.retain_grad() + b_tri.retain_grad() op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") - ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a - rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b - rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest) - # torch result - ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a - tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b - ta = ta.transpose(2, 3) if TRANS_A else ta - tb = tb.transpose(2, 3) if TRANS_B else tb - tc = torch.matmul(ta, tb) - tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc - tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc + c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest) + triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest) + da_tri = a_tri.grad + db_tri = b_tri.grad # compare - triton.testing.assert_almost_equal(rc, tc) + triton.testing.assert_almost_equal(c_ref, c_tri) + triton.testing.assert_almost_equal(da_ref, da_tri) + triton.testing.assert_almost_equal(db_ref, db_tri) -@pytest.mark.parametrize("BLOCK", [16, 32, 64]) 
-@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792]) -@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) -def test_softmax(BLOCK, WIDTH, DTYPE): - is_causal = True +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): # set seed torch.random.manual_seed(0) - Z, H, M, N = 1, 1, WIDTH, WIDTH - scale = 0.4 - # create inputs - layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) - x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") - at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda") + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout # make sure each row has at least one non-zero element - torch.diagonal(layout)[:] = 1 - torch.diagonal(at_mask)[:] = 1 - kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda") - kp_mask[:] = 0 - kp_mask[kp_mask == 1.0] = float("-inf") - # triton result - op = triton.ops.blocksparse.softmax(layout, BLOCK) - tx = triton.testing.sparsify_tensor(x, layout, BLOCK) - ty = op( - tx, - scale=scale, - key_padding_mask=kp_mask, - key_padding_mask_mode="add", - attn_mask=at_mask.to(DTYPE), - attn_mask_mode="mul", - is_causal=is_causal, - ) - # torch result - rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) - # broadcast at_mask to the same shape as rx + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = triton.testing.make_pair(a_shape) + dout_ref, dout_tri = triton.testing.make_pair(a_shape) + # compute [torch] + a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device="cuda") if is_causal: at_mask = torch.tril(at_mask) - M = at_mask[None, None, :, :] + torch.zeros_like(rx) - rx[M == 0] = float("-inf") - # rx += kp_mask[:, None, None, :] - ry = torch.softmax(rx * scale, -1) - ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK) + da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad # compare - triton.testing.assert_almost_equal(ry, ty) + triton.testing.assert_almost_equal(out_tri, out_ref) + triton.testing.assert_almost_equal(da_tri, da_ref) @pytest.mark.parametrize("block", [16, 32, 64]) @@ -99,14 +130,6 @@ def test_attention_fwd_bwd( qkvs = [ torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) ] - attn_mask = torch.tril( - torch.ones( - [n_ctx, n_ctx], - device="cuda", - dtype=dtype, - ), - diagonal=0, - ) # Triton: n_blocks = n_ctx // block @@ -115,7 +138,7 @@ def test_attention_fwd_bwd( query.retain_grad() 
key.retain_grad() value.retain_grad() - attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale) + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) # ad hoc loss loss = (attn_out ** 2).mean() loss.backward() @@ -123,6 +146,8 @@ def test_attention_fwd_bwd( # Torch version: torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) torch_q.retain_grad() torch_k.retain_grad() @@ -147,7 +172,6 @@ def test_attention_fwd_bwd( def triton_attention( layout, block: int, - attn_mask: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -155,12 +179,9 @@ def triton_attention( ): sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) - sparse_softmax = triton.ops.blocksparse.softmax( - layout, - block, - ) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) w = sparse_dot_sdd_nt(query, key) - w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul") + w = sparse_softmax(w, scale=scale, is_causal=True) a = sparse_dot_dsd_nn(w, value) return a diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index ebcea8e1e..0fa1a5878 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ sizes = torch.sum(layout, 2 if trans else 1) - head_id, col_id = sizes.nonzero(as_tuple=True) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) sizes = sizes.flatten() segments = sizes * step # pointer increments @@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device): # ------------------------------- # dense input pointer increments # ------------------------------- - # given a list of the indices for the first element of each non-zero block. 
- # For example, for the indices - # [32, 80, 128, 256, 288] - # we would generate the increments - # [32, 48, 48, 128, 32] - # ^ - # index of first element # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) # that is smaller than the block size, so we need to do a bit of extra work # to handle this case @@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- # Dense = Dense x Sparse (DDS) # ----------------------------- - - -@triton.jit -def _dds_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ka, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, - TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, -): - # ------------ # - # - Prologue - # - # ------------ # - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) - pid_z = tl.program_id(2) - header = lut + pid_n * 4 - offset = tl.load(header + 0) - AS1 = tl.load(header + 1) - column = tl.load(header + 2) - off_h = tl.load(header + 3) - pinc = lut + offset - # initialize pointers to A (dense) - offs_am = pid_m * TILE_M + tl.arange(0, TILE_M) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) - start_ak = tl.load(pinc) - start_ak = tl.multiple_of(start_ak, 8) - offs_ak = start_ak + tl.arange(0, TILE_K) - ptrs_a = A + pid_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ka - # initialize pointers to B (sparse) - block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) - offs_bn = tl.arange(0, TILE_N) - offs_bk = tl.arange(0, TILE_K) - ptrs_b = B + pid_z * stride_zb \ - + block_id * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + offs_bk[:, None] * stride_bk - # ---------------- # - # Inner Loop # - # ---------------- # - acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) - for k in range(AS1, 0, -TILE_K): - a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0) - b = tl.load(ptrs_b, mask=True) - acc += tl.dot(a, b) - pinc += 2 - inc_a = tl.load(pinc) - inc_b = tl.load(pinc + 1) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = inc_a * stride_ka - ptrs_a += inc_a - ptrs_b += inc_b - # ---------------- # - # Epilogue # - # ---------------- # - c = acc.to(C.dtype.element_ty) - # initialize pointers to C (dense) - offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) - offs_cn = column * TILE_N + tl.arange(0, TILE_N) - ptrs_c = C + off_h * stride_hc \ - + pid_z * stride_zc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc - # write back - tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0) +# AB = (B^T A^T)^T def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): - if a.stride(2) != 1 and a.stride(3) != 1: - a = a.contiguous() - if b.stride(2) != 1 and b.stride(3) != 1: - b = b.contiguous() - # shapes / dtypes - AS0 = a.size(0) - AS1 = a.size(1) - AS2 = a.size(3 if trans_a else 2) - BS2 = block * spdims[1 if trans_b else 2] - dtype = a.dtype - # output - CS0 = AS0 - CS1 = AS1 - CS2 = BS2 if trans_c else AS2 - CS3 = AS2 if trans_c else BS2 - if out is None: - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - else: - assert out.shape == (CS0, CS1, CS2, CS3) - c = out - TILE_M = {16: 256, 32: 256, 64: 128, 128: 
128}[block] - grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] - _dds_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), - AS2, BS2, lut, - TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4, - num_warps=4, GROUP_SIZE_M=4, - ) - return c + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) ############## # MAIN API # diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 6ac76dcc4..bb915be13 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -5,230 +5,235 @@ import triton.language as tl def num_warps(n): - if n < 512: + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: return 4 - if n < 2048: + if n <= 4096: return 8 return 16 -@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) -@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])}) @triton.jit -def _forward( - X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr, - KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr, +def _blocksparse_softmax_fwd( + Out, A, stride_xz, LUT, + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, ): - pidhm = tl.program_id(0) - pidz = tl.program_id(1) + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) # create index ranges - rxm = pidhm % BLOCK - rbm = pidhm // BLOCK - rxn = tl.arange(0, TN) % BLOCK - rbn = tl.arange(0, TN) // BLOCK + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE # extract information from LUT - header = LUT + rbm * 2 + header = LUT + (hm // BLOCK_SIZE) * 2 size = tl.load(header + 0) offset = tl.load(header + 1) - check = rbn < size - rbmn = tl.where(check, rbn, size - 1) - # block id and column id - blockid = tl.load(LUT + offset + rbmn * 4 + 0) - columnid = tl.load(LUT + offset + rbmn * 4 + 1) - rowid = tl.load(LUT + offset + rbmn * 4 + 2) - headid = tl.load(LUT + offset + rbmn * 4 + 3) - # pointers to X - px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - x = tl.load(px, mask=check, other=-float('inf')) - x = x.to(tl.float32) - # apply scale - if APPLY_SCALE: - x = x * scale - # apply RPE - if APPLY_RPE: - prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn - rpe = tl.load(prpe, mask=check, other=0) - x = x + rpe - # apply key-padding mask - if APPLY_KP_MASK: - pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn - kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) - if KP_MASK_MUL: - kp_m = tl.where(kp_m == 0, -float('inf'), 0.) 
- x = x + kp_m - # apply attention mask - if APPLY_ATTN_MASK: - pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn - attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) - if ATTN_MASK_MUL: - attn_m = tl.where(attn_m == 0, -float('inf'), 0.) - x = x + attn_m + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) # apply causal mask - is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm - x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) + out = tl.where((ns > m) & is_causal, -float("inf"), out) # computation - x = tl.softmax(x) - tl.store(px, x, mask=check) + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) -@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) -@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']}) @triton.jit -def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr): - pidhm = tl.program_id(0) - pidz = tl.program_id(1) +def _blocksparse_softmax_bwd( + DA, stride_zdx, + DOut, stride_zdout, + Out, stride_zout, + scale, + LUT, + DR, extent, stride_zr, stride_hr, stride_er, + is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, +): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) # create index ranges - rxm = pidhm % BLOCK - rbm = pidhm // BLOCK - rxn = tl.arange(0, TN) % BLOCK - rbn = tl.arange(0, TN) // BLOCK - # extract information from look-up table - header = LUT + rbm * 2 + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 size = tl.load(header + 0) offset = tl.load(header + 1) - # bounds checking on lut - check = rbn < size - rbmn = tl.where(check, rbn, size - 1) - # initialize pointers to block-sparse input - blockid = tl.load(LUT + offset + rbmn * 4) - X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - # compute fused softmax backward - x = tl.load(X, mask=check, other=0) - dx = tl.load(DX, mask=check, other=0) - x = x.to(tl.float32) - dx = dx.to(tl.float32) - y = x * (dx - tl.sum(x * dx, 0)) * scale - tl.store(DX, y, mask=check) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not 
need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + da = a * (dout - tl.sum(a * dout, 0)) + da = tl.where((ns > m) & is_causal, 0., da) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) class _softmax(torch.autograd.Function): @staticmethod def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() # sizes along rows - sizes = layout.sum(-1).view(-1) + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block # offsets in block format offsets = torch.zeros_like(sizes) offsets[1:] = torch.cumsum(sizes[:-1], dim=0) # block indices - layout_sum = sizes.sum() - idx = torch.arange(layout_sum, device=layout.device) - layout_nonzero = layout.nonzero(as_tuple=False) - head = layout_nonzero[:, 0] - rows = layout_nonzero[:, 1] - columns = layout_nonzero[:, 2] - core = torch.stack((idx, columns, rows, head), dim=1).view(-1) - # construct look-up table - offsets = offsets * 4 + 2 * sizes.numel() + columns = layout.nonzero(as_tuple=False)[:, 2] header = torch.stack((sizes, offsets), dim=1).view(-1) - lut = torch.cat((header, core)).type(torch.int32).to(device) - return lut, int(sizes.max()) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) @staticmethod def forward( - ctx, x, scale, rpe, - key_padding_mask, attn_mask, - kp_mask_mode, attn_mask_mode, - is_causal, - spdims, block, lut, maxlut + ctx, a, scale, rel_logits, is_causal, + spdims, block, lut, maxlut, is_dense ): - apply_scale = False if scale == 1.0 else True - # handle None rpe - if rpe is None: - apply_rpe = False - stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0 - rpe = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_rpe = True - stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2) - # handle None key_padding_mask - if key_padding_mask is None: - apply_kp_mask = False - stride_zkpm = 0 - key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_kp_mask = True - stride_zkpm = key_padding_mask.stride(0) - # handle None attention_mask - if attn_mask is None: - apply_attn_mask = False - stride_zattnm = 0 - attn_mask = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_attn_mask = True - stride_zattnm = attn_mask.stride(0) - # run kernel - M = x.shape[0] - grid = [spdims[0] * spdims[1] * block, M] - _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0), - stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - BLOCK=block, - APPLY_SCALE=apply_scale, - APPLY_RPE=apply_rpe, - APPLY_KP_MASK=apply_kp_mask, - APPLY_ATTN_MASK=apply_attn_mask, - KP_MASK_MUL=(kp_mask_mode == 'mul'), - ATTN_MASK_MUL=(attn_mask_mode == 'mul')) + if 
scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn + scale, + is_causal, + BLOCK_SIZE=block, + ROW_SIZE=triton.next_power_of_2(maxlut), + IS_DENSE=is_dense, + num_warps=num_warps(maxlut) + ) # save to context - ctx.mark_dirty(x) - ctx.save_for_backward(x, lut) + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) ctx.spdims = spdims ctx.block = block ctx.maxlut = maxlut ctx.scale = scale - ctx.apply_scale = apply_scale - ctx.apply_rpe = apply_rpe - ctx.apply_kp_mask = apply_kp_mask - ctx.apply_attn_mask = apply_attn_mask - ctx.kp_mask_mode = kp_mask_mode - ctx.attn_mask_mode = attn_mask_mode - return x + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out @staticmethod - def backward(ctx, dx): + def backward(ctx, dout): # retrieve from context - x, lut = ctx.saved_tensors + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) # run kernel - M = x.shape[0] - grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] - _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block) - return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), + dout, dout.stride(0), + out, out.stride(0), + ctx.scale, + lut, + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], + ctx.is_causal, + BLOCK_SIZE=ctx.block, + ROW_SIZE=triton.next_power_of_2(ctx.maxlut), + IS_DENSE=ctx.is_dense, + num_warps=num_warps(ctx.maxlut) + ) + return (da, None, None, dr, None, + None, None, None, None, None, + None, + None, None, None, + None, + None, None, None + ) class softmax: - - def make_lut(self, device): - key = (device, ) - if key not in self.lut_cache: - self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device) - return self.lut_cache[key] - - def __init__(self, layout, block): + def __init__(self, layout, block, device, is_dense=False): self.spdims = layout.shape self.layout = layout self.block = block - self.lut_cache = dict() + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense - def __call__( - self, x, scale=1., rpe=None, - key_padding_mask=None, attn_mask=None, - key_padding_mask_mode='add', attn_mask_mode='add', - is_causal=False - ): - if rpe is not None and rpe.dtype != x.dtype: - raise ValueError('relative position embedding must be %s' % x.dtype) - if attn_mask is not None and attn_mask.dtype != x.dtype: - raise ValueError('Attention mask must be %s' % x.dtype) - if key_padding_mask is not None and key_padding_mask.dtype != x.dtype: - raise ValueError('Key padding mask must be %s' % x.dtype) - lut, maxlut = self.make_lut(x.device) - x = _softmax.apply( - x, scale, rpe, - key_padding_mask, attn_mask, - 
key_padding_mask_mode, attn_mask_mode, - is_causal, - self.spdims, self.block, - lut, maxlut + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError("relative position embedding must be %s" % a.dtype) + a = _softmax.apply( + a, scale, rel_logits, is_causal, + self.spdims, self.block, self.lut, self.maxlut, self.is_dense, ) - return x + return a diff --git a/python/triton/testing.py b/python/triton/testing.py index 199226ea1..c720f64cf 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block): return ret +def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None): + if data is None: + data = torch.randn(shape, dtype=torch.float32, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().float() + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + def cutlass_matmul(a, b): if _cutlass is None: raise RuntimeError("Cannot find cutlass library")
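
For reference, a minimal usage sketch of the reworked block-sparse API, assuming a CUDA device is available. The shapes, layout pattern, and scale value are illustrative only and mirror the calls exercised by the updated tests and triton_attention helper above; this is not part of the patch itself.

import torch
import triton

# toy dimensions and a random block-sparsity layout (illustrative values)
Z, H, M, N, K, BLOCK = 2, 2, 256, 256, 64, 32
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
# keep every block-row non-empty (diagonal blocks always present)
layout[:, torch.arange(M // BLOCK), torch.arange(N // BLOCK)] = 1

# sparse = dense x dense^T (SDD) scores, then fused block-sparse causal softmax
dot_sdd = triton.ops.blocksparse.matmul(layout, BLOCK, "sdd",
                                        trans_a=False, trans_b=True, device="cuda")
softmax = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda")

q = torch.randn((Z, H, M, K), dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn((Z, H, N, K), dtype=torch.float16, device="cuda", requires_grad=True)
w = dot_sdd(q, k)                                      # block-sparse attention scores
p = softmax(w, scale=1.0 / K ** 0.5, is_causal=True)   # fused scale + causal mask + softmax
p.sum().backward()                                      # gradients flow through both ops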