From 9bb54402b3ff33f19ee2cf9decbbc27f4d9ecc08 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 29 Nov 2022 20:00:34 +0100 Subject: [PATCH] [FRONTEND][BACKEND] Small fixes to multiple_of, num_programs, axisinfo; enable block-sparse tests (#927) --- lib/Analysis/AxisInfo.cpp | 15 ++ .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- python/src/triton.cc | 11 + python/tests/test_blocksparse.py | 188 ++++++++++++++++++ python/triton/compiler.py | 4 +- python/triton/language/semantic.py | 14 +- python/triton/testing.py | 6 +- test/Conversion/tritongpu_to_llvm.mlir | 6 +- 8 files changed, 229 insertions(+), 17 deletions(-) create mode 100644 python/tests/test_blocksparse.py diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 5391574dd..58e3efa3c 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end())); } } + // TODO: refactor & complete binary ops // Addition if (llvm::isa(op)) { auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) { @@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation( curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), newContiguity, newDivisibility, newConstancy); } + // Remainder + if (llvm::isa(op)) { + auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { + return gcd(lhs.getContiguity(d), rhs.getDivisibility(d)); + }; + auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { + return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)); + }; + auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) { + return gcd(lhs.getConstancy(d), rhs.getConstancy(d)); + }; + curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), + newContiguity, newDivisibility, newConstancy); + } // TODO: All other binary ops if (llvm::isa(op)) { auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index cdc230d49..138e0bd2d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2074,7 +2074,7 @@ struct GetNumProgramsOpConversion Location loc = op->getLoc(); assert(op.axis() < 3); - Value blockId = rewriter.create<::mlir::gpu::BlockDimOp>( + Value blockId = rewriter.create<::mlir::gpu::GridDimOp>( loc, rewriter.getIndexType(), dims[op.axis()]); auto llvmIndexTy = getTypeConverter()->getIndexType(); rewriter.replaceOpWithNewOp( diff --git a/python/src/triton.cc b/python/src/triton.cc index 003910af6..419beed8c 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -187,6 +187,7 @@ void init_triton_ir(py::module &&m) { /* issue a warning */ } }) + .def("get_context", &mlir::Value::getContext) .def("replace_all_uses_with", [](mlir::Value &self, mlir::Value &newValue) { self.replaceAllUsesWith(newValue); @@ -335,6 +336,16 @@ void init_triton_ir(py::module &&m) { return funcs[0]; }); + m.def("make_attr", + [](const std::vector &values, mlir::MLIRContext &context) { + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + {static_cast(values.size())}, + mlir::IntegerType::get(&context, 32)), + values) + .cast(); + }); + m.def( "parse_mlir_module", [](const std::string &inputFilename, mlir::MLIRContext &context) { diff --git a/python/tests/test_blocksparse.py b/python/tests/test_blocksparse.py new file mode 100644 index 000000000..8ddf190be --- /dev/null +++ b/python/tests/test_blocksparse.py @@ -0,0 +1,188 @@ +import pytest +import torch + +import triton + +# TODO: float32 fails + +@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) +@pytest.mark.parametrize("TRANS_B", [False, True]) +@pytest.mark.parametrize("TRANS_A", [False, True]) +@pytest.mark.parametrize("BLOCK", [16, 32, 64]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=384): + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK) + # create inputs + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) + shape = { + "sdd": (M, N), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), + }[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE) + b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE) + dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.requires_grad_().retain_grad() + b_ref.requires_grad_().retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, + b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad + # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.requires_grad_().retain_grad() + b_tri.requires_grad_().retain_grad() + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") + c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest) + triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest) + da_tri = a_tri.grad + db_tri = b_tri.grad + # compare + triton.testing.assert_almost_equal(c_ref, c_tri) + triton.testing.assert_almost_equal(da_ref, da_tri) + triton.testing.assert_almost_equal(db_ref, db_tri) + + +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout + # make sure each row has at least one non-zero element + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = triton.testing.make_pair(a_shape) + dout_ref, dout_tri = triton.testing.make_pair(a_shape) + # compute [torch] + a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device="cuda") + if is_causal: + at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK) + da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad + # compare + triton.testing.assert_almost_equal(out_tri, out_ref) + triton.testing.assert_almost_equal(da_tri, da_ref) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_attention_fwd_bwd( + block, + dtype, + input_scale=1.0, + scale=1 / 8.0, + n_ctx=256, + batch_size=2, + n_heads=2, +): + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) + ] + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + query.retain_grad() + key.retain_grad() + value.retain_grad() + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) + # ad hoc loss + loss = (attn_out ** 2).mean() + loss.backward() + grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) + torch_q.retain_grad() + torch_k.retain_grad() + torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out ** 2).mean() + torch_loss.backward() + torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] + + # comparison + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + triton.testing.assert_almost_equal(loss, torch_loss) + for g1, g2 in zip(grads, torch_grads): + triton.testing.assert_almost_equal(g1, g2) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +def triton_attention( + layout, + block: int, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, is_causal=True) + a = sparse_dot_dsd_nn(w, value) + return a diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 42b47d3ad..cee2a8663 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1515,7 +1515,7 @@ class CudaUtils(object): } } - #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); } + #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; } static PyObject* loadBinary(PyObject* self, PyObject* args) { const char* name; @@ -1530,7 +1530,6 @@ class CudaUtils(object): CUmodule mod; int32_t n_regs = 0; int32_t n_spills = 0; - Py_BEGIN_ALLOW_THREADS; // create driver handles CUDA_CHECK(cuModuleLoadData(&mod, data)); CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); @@ -1548,7 +1547,6 @@ class CudaUtils(object): CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); } - Py_END_ALLOW_THREADS; if(PyErr_Occurred()) { return NULL; diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 741356e1a..36912eb2b 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1117,16 +1117,16 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: ## def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: - if len(x.shape) != len(values): - raise ValueError("Shape of input to multiple_of does not match the length of values") - x.handle.multiple_of(values) - return x - - + if len(x.shape) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: if len(x.shape) != len(values): raise ValueError("Shape of input to max_contiguous does not match the length of values") - x.handle.max_contiguous(values) + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) return x diff --git a/python/triton/testing.py b/python/triton/testing.py index 2c9ece2fe..95a05349c 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -34,12 +34,12 @@ def sparsify_tensor(x, mask, block): return ret -def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None): +def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32): if data is None: - data = torch.randn(shape, dtype=torch.float32, device=device) + data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) ref_ret = data ref_ret = ref_ret * alpha + beta - ref_ret = ref_ret.half().float() + ref_ret = ref_ret.half().to(dtype) if trans: ref_ret = ref_ret.t().requires_grad_() ref_ret = ref_ret.detach().requires_grad_() diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ac632c13b..dc4acb183 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -917,9 +917,9 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { module attributes {"triton_gpu.num-warps" = 4 : i32} { func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { - // CHECK: nvvm.read.ptx.sreg.ntid.x - // CHECK: nvvm.read.ptx.sreg.ntid.y - // CHECK: nvvm.read.ptx.sreg.ntid.z + // CHECK: nvvm.read.ptx.sreg.nctaid.x + // CHECK: nvvm.read.ptx.sreg.nctaid.y + // CHECK: nvvm.read.ptx.sreg.nctaid.z %blockdimx = tt.get_num_programs {axis=0:i32} : i32 %blockdimy = tt.get_num_programs {axis=1:i32} : i32 %blockdimz = tt.get_num_programs {axis=2:i32} : i32