From de2dd04c8a1fac782e3cd474d7a9bbb23a7b7af5 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Tue, 23 Aug 2022 12:47:09 +0800 Subject: [PATCH] [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69) * [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering * Clean code Co-authored-by: goostavz Co-authored-by: Yan Chunwei --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 61 ++++++++++--------- python/src/triton.cc | 4 +- python/tests/test_vecadd_no_scf.py | 34 ++++++++++- test/Conversion/triton_to_llvm.mlir | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 4 +- 5 files changed, 70 insertions(+), 35 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 0d8df10ba..96b24f1ef 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -309,10 +309,10 @@ public: PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit) {} - SmallVector + SmallVector getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems, ConversionPatternRewriter &rewriter) const { - SmallVector results(elems); + SmallVector results(elems); for (unsigned i = 0; i < elems; ++i) { Type type = llvmStruct.getType().cast().getBody()[i]; @@ -710,7 +710,7 @@ struct StoreOpConversion PtxIOInstr asmStoreInstr("st"); asmStoreInstr.predicate(maskElems[vecIdx], "b"); - asmStoreInstr.global().v(width).b(nWords); + asmStoreInstr.global().b(width).v(nWords); llvm::SmallVector asmArgs; @@ -970,7 +970,10 @@ struct LoadOpConversion unsigned numElems = getElemsPerThread(blockedLayout, shape); auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter); auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter); - auto otherVals = getElementsFromStruct(loc, other, numElems, rewriter); + SmallVector otherVals; + if (other != nullptr) { + otherVals = getElementsFromStruct(loc, other, numElems, rewriter); + } unsigned nbits = elemTy.isa() ? elemTy.cast().getWidth() : elemTy.cast().getWidth(); @@ -1039,31 +1042,33 @@ struct LoadOpConversion asmOss << ", $" << n_words + 2; asmOss << ";"; SmallVector others; - for (size_t ii = 0; ii < n_words; ii++) { - size_t size = width / nbits; - auto vecTy = LLVM::getFixedVectorType(elemTy, size); - Value v = rewriter.create(loc, vecTy); - for (size_t s = 0; s < size; s++) { - Value falseVal = otherVals[i + ii * size + s]; - Value sVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), s); - v = rewriter.create(loc, vecTy, v, falseVal, - sVal); + if (other != nullptr) { + for (size_t ii = 0; ii < n_words; ii++) { + size_t size = width / nbits; + auto vecTy = LLVM::getFixedVectorType(elemTy, size); + Value v = rewriter.create(loc, vecTy); + for (size_t s = 0; s < size; s++) { + Value falseVal = otherVals[i + ii * size + s]; + Value sVal = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), s); + v = rewriter.create(loc, vecTy, v, falseVal, + sVal); + } + v = rewriter.create( + loc, IntegerType::get(getContext(), width), v); + asmOss << "\n "; + asmOss << "@!$" << n_words << " mov.u" << width; + asmOss << " $" << ii << ", "; + std::ios_base::fmtflags flags(asmOss.flags()); + if (otherIsSplatConstInt) + asmOss << "0x" << std::hex << splatVal; + else { + asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii; + others.push_back(v); + } + asmOss.flags(flags); + asmOss << ";"; } - v = rewriter.create( - loc, IntegerType::get(getContext(), width), v); - asmOss << "\n "; - asmOss << "@!$" << n_words << " mov.u" << width; - asmOss << " $" << ii << ", "; - std::ios_base::fmtflags flags(asmOss.flags()); - if (otherIsSplatConstInt) - asmOss << "0x" << std::hex << splatVal; - else { - asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii; - others.push_back(v); - } - asmOss.flags(flags); - asmOss << ";"; } // --- // create inline ASM signature diff --git a/python/src/triton.cc b/python/src/triton.cc index c2ef6b4ba..f0b78e15f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -258,9 +258,9 @@ void parse_args(py::list &args, py::list do_not_specialize, void parse_args(py::list &args, py::list &arg_names, std::string ¶ms, size_t ¶ms_size, py::dict constants) { - char *params_ptr = params.data(); - size_t len = PyList_Size(args.ptr()); + params.reserve(8 * len); // 8 max bytes by argument + char *params_ptr = params.data(); for (int i = 0; i < len; i++) { py::object arg = args[i]; auto arg_ptr = arg.ptr(); diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py index 995ef5fae..d1604d8f1 100644 --- a/python/tests/test_vecadd_no_scf.py +++ b/python/tests/test_vecadd_no_scf.py @@ -1,7 +1,12 @@ +import torch +from torch.testing import assert_allclose + import triton import triton.language as tl +import triton.runtime as runtime NUM_WARPS = 4 +BLOCK_SIZE = 256 # triton kernel @@ -22,6 +27,31 @@ def test_vecadd_no_scf(): z_ptrs = z_ptr + offset tl.store(z_ptrs, z) - ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx") + ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx") - print(ret) + torch.zeros([10], device=torch.device('cuda')) + device = torch.cuda.current_device() + binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", + device=device, + constants={"BLOCK_SIZE_N": BLOCK_SIZE}, + num_warps=NUM_WARPS, + num_stages=3) + grid = lambda META: (1, ) + + x = torch.randn((256,), device='cuda', dtype=torch.float32) + y = torch.randn((256,), device='cuda', dtype=torch.float32) + z = torch.empty((256,), device=x.device, dtype=x.dtype) + runtime.launch_kernel(fn=kernel, + binary=binary, + grid=grid, + num_warps=NUM_WARPS, + num_stages=3, + x_ptr=x, + stride_xn=x.stride(0), + y_ptr=y, + stride_yn=y.stride(0), + z_ptr=z, + stride_zn=z.stride(0), + BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE)) + golden_z = x + y + assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7) diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir index 5bb411c35..ec5c470ca 100644 --- a/test/Conversion/triton_to_llvm.mlir +++ b/test/Conversion/triton_to_llvm.mlir @@ -29,7 +29,7 @@ func @test_store_splat(%ptr: !tt.ptr) { %vs = tt.splat %a : (f32) -> tensor<128xf32> %mask = tt.splat %true : (i1) -> tensor<128xi1> - // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };", + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };", // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> tt.store %ptrs, %vs, %mask, {} : tensor<128xf32> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 473607bc3..4cad33ff3 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr, i32) -> !llvm.struct<()> tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0> return }