[BACKEND] Extracting numWarps from tritonGPU module (#39)

This commit is contained in:
Yan Chunwei
2022-08-09 00:40:20 +08:00
committed by GitHub
parent 920723cf3d
commit 83ef74f248
8 changed files with 66 additions and 17 deletions

View File

@@ -1,9 +1,10 @@
// RUN: triton-opt %s -convert-triton-to-tritongpu
// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
func @ops() {
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
return
}
}

View File

@@ -1,9 +1,15 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// CHECK: attributes {nvvm.maxntidx = 96 : i32}
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
} // end module