[BACKEND] Add LLVM-translation for store and splat ops (#47)
This commit is contained in:
36
test/Conversion/triton_to_llvm.mlir
Normal file
36
test/Conversion/triton_to_llvm.mlir
Normal file
@@ -0,0 +1,36 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s
|
||||
|
||||
func @test_splat(%ptr: !tt.ptr<f32>) {
|
||||
// Here, 128 elements, 64(2*32) threads, so each need to process 2 elements
|
||||
//
|
||||
// CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
|
||||
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
|
||||
// CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
|
||||
// CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
|
||||
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
%a = arith.constant 1.0 : f32
|
||||
%true = arith.constant 1 : i1
|
||||
%b = tt.splat %a : (f32) -> tensor<128xf32>
|
||||
|
||||
// Here, each thread process only 1 element
|
||||
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
|
||||
%mask = tt.splat %true : (i1) -> tensor<64xi1>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @test_store_splat(%ptr: !tt.ptr<f32>) {
|
||||
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
%a = arith.constant 1.0 : f32
|
||||
%true = arith.constant 1 : i1
|
||||
|
||||
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
||||
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
||||
|
||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
|
||||
// CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>
|
||||
|
||||
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
|
||||
|
||||
return
|
||||
}
|
@@ -1,11 +1,10 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||
// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
|
||||
// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
// CHECK: llvm.return
|
||||
|
Reference in New Issue
Block a user