[OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)

Continue the work https://github.com/openai/triton/pull/990

# Background
The `versionMinor` in MmaEncodingAttr holds some states of DotOp's
operands in Volta, while such operands will be modified by some
patterns, making the states out-of-date.

This PR helps to correct the states.

# Implementation
It adds three new patterns:

1. `CollectMmaToUpdateForVolta` helps to collect and build a map holding
the MmaEncodingAttr instances with wrong states and create new correct
ones for them,
2. `UpdateMMAVersionMinorForVolta` helps to replace the Ops generating
the wrong MmaEncodingAttr instances with new correct ones, currently it
supports the following Ops
    a. `convert_layout[X -> mma]`
    b. `arith.constant SplatAttr : !tensor<mma>`
    c. `dot ... : !tensor<mma>`

# Limitation
This PR chooses the mapping way to bypass the IR walk complexity from
the circular dependency between dot_operand[parent] and mma.
We use the MmaEncodingAttr instance as the mapping key, but there might
be multiple DotOp holding different DotOprand(IsMMAv1Row) that have the
same wrong MmaEncodingAttr instance.
To make each DotOp's (wrong) MmaEncodingAttr unique, we might need an ID
field to MmaEncodingAttr.
This commit is contained in:
Yan Chunwei
2022-12-28 12:24:01 +08:00
committed by GitHub
parent fd2da4aff6
commit 2ba74d2729
3 changed files with 281 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}
// -----
// check the UpdateMMAVersionMinorForVolta pattern
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
// Here, the isMMAv1Row of a and b's dot_operands mismatch #mma0's versionMinor,
// and the pattern should update the versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// It creates a new MMA layout to fit with $a and $b's dot_operand
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_mmav1
func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
%res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
return %res : tensor<16x16xf32, #blocked0>
}
}