From 74f3d7a80fbb0d69beb59fa7f25a7590eb0fac1b Mon Sep 17 00:00:00 2001
From: Jokeren <kerenzhou@openai.com>
Date: Tue, 6 Dec 2022 12:53:25 -0800
Subject: [PATCH] Fix

---
 lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp |  2 +-
 test/Conversion/tritongpu_to_llvm.mlir             | 12 +++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index c5779c021..a5fb4edc0 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -4107,7 +4107,7 @@ struct InsertSliceAsyncOpConversion
         // Write shared memory if predicate is true
         auto *valOperand = ptxBuilder.newOperand(v, "r");
         auto &st = *ptxBuilder.create<>("st");
-        st.shared().o("b" + std::to_string(byteWidth));
+        st.shared().o("b" + std::to_string(bitWidth));
         st(dstOperand, valOperand).predicate(pred);
       }
     }
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 98a475ad3..c897dde4d 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -435,7 +435,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: basic_insert_slice_async_v4
+  // CHECK-LABEL: basic_insert_slice_async_mask
   func @basic_insert_slice_async_mask(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
     %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
     %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
@@ -456,8 +456,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %true = arith.constant 1 : i1
     %true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>
+    // CHECK: llvm.select
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
+    // CHECK: llvm.select
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: cp.async.commit_group
@@ -478,7 +480,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: basic_insert_slice_async_v4
+  // CHECK-LABEL: basic_insert_slice_async_mask_other
   func @basic_insert_slice_async_mask_other(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
     %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
     %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
@@ -498,14 +500,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
     %true = arith.constant 1 : i1
     %true_tensor = tt.splat %true : (i1) -> tensor<16x64xi1, #AL>
+    %other = arith.constant 1.0 : f32
+    %other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL>
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: st.shared.b128 [ ${{.*}} + 0 ]
     // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: st.shared.b128 [ ${{.*}} + 16 ]
     // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: cp.async.commit_group
-    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor, %true_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %true_tensor, %other_tensor {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
     return
   }
 }