[BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69)

* [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering

* Clean code

Co-authored-by: goostavz <gzhu@nvidia.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
goostavz
2022-08-23 12:47:09 +08:00
committed by GitHub
parent 92ef552a54
commit de2dd04c8a
5 changed files with 70 additions and 35 deletions

View File

@@ -29,7 +29,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };",
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>

View File

@@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
return
}