// 157 lines, 15 KiB, plain text (metadata from the original paste).
// Register (blocked) layouts. One repetition of a layout spans
// sizePerThread * threadsPerWarp * warpsPerCTA elements per dimension:
// 64x32 for #blocked0 and 16x128 for #blocked1. Both use 8 warps, which
// matches "triton_gpu.num-warps" = 8 on the module below.
//   #blocked0: A-tile loads (tensor<256x32>).
//   #blocked1: B-tile loads (tensor<32x128>) and the final C store (tensor<256x128>).
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// Accumulator layout of tt.dot: MMA version 1, with 4x2 = 8 warps.
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 2]}>
// Shared layouts of the double-buffered A tiles (#shared0) and B tiles
// (#shared1). vec / perPhase / maxPhase are Triton's shared-layout swizzling
// parameters.
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
// Software-pipelined matmul kernel in TritonGPU IR: C = A x B.
// - Inputs are fp16; accumulation is fp32 (#mma, version 1); the result is
//   truncated to fp16 on store.
// - Each program computes one 256x128 tile of C.
// - The K dimension is walked in 32-wide chunks, double-buffered in two slots
//   of tensor<2x...> shared-layout buffers.
// - Each chunk is consumed by two K=16 tt.dot ops.
//
// Arguments (roles inferred from how each value is used below):
//   %arg0 / %arg1 / %arg2 : base pointers of A, B, C (f16)
//   %arg3 : row count of A and C ("M"); sets the row-tile count and row mask
//   %arg4 : column count of B and C ("N")
//   %arg5 : reduction length ("K"); upper bound of the scf.for
//   %arg6 / %arg7 / %arg8 : row strides of A, B, C. Column strides are
//                           implicitly 1, since column offsets are added unscaled.
//
// NOTE(review): A/B loads are masked per tile only (k + 32 < K), not per
// element. This presumably assumes K % 32 == 0 -- confirm with the caller.
module attributes {"triton_gpu.num-warps" = 8 : i32} {
  func public @_kernel_0d1d2d3d4d5d6d7d8d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    // Per-chunk pointer advance for A (32 columns, column stride 1).
    %cst = arith.constant dense<32> : tensor<256x32xi32, #blocked0>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %c8_i32 = arith.constant 8 : i32
    %c255_i32 = arith.constant 255 : i32
    %c127_i32 = arith.constant 127 : i32
    %c32_i32 = arith.constant 32 : i32
    // %c0 also serves as the (index-typed) slot 0 of the double buffers in the
    // prologue; it replaces a duplicate misnamed `%c0_i32 : index` constant.
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c256_i32 = arith.constant 256 : i32
    %c128_i32 = arith.constant 128 : i32
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    %1 = tt.get_program_id {axis = 1 : i32} : i32
    // Grouped tile ordering: map the linear id %0 to row tile %13 and column tile %15.
    //   %3  = cdiv(M, 256)       row-tile count
    //   %5  = cdiv(N, 128)       column-tile count
    //   %6  = 8 * %5             programs per group of 8 row tiles
    //   %11 = min(%3 - %8, 8)    row tiles in this (possibly short) group
    %2 = arith.addi %arg3, %c255_i32 : i32
    %3 = arith.divsi %2, %c256_i32 : i32
    %4 = arith.addi %arg4, %c127_i32 : i32
    %5 = arith.divsi %4, %c128_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %0, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.cmpi slt, %9, %c8_i32 : i32
    %11 = select %10, %9, %c8_i32 : i32
    %12 = arith.remsi %0, %11 : i32
    %13 = arith.addi %8, %12 : i32
    %14 = arith.remsi %0, %6 : i32
    %15 = arith.divsi %14, %11 : i32
    // Row offsets %16 + [0, 256) and column offsets %21 + [0, 128). Each is
    // materialized once per slice layout: the A load, the B load and the C store.
    %16 = arith.muli %13, %c256_i32 : i32
    %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %19 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %20 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %21 = arith.muli %15, %c128_i32 : i32
    %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %23 = tt.splat %21 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %24 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %25 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %26 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    // K offsets start at program_id(1) * 32.
    // NOTE(review): the loop below always runs 0..K in steps of 32, whatever
    // program_id(1) is. This presumably only works with a grid of size 1 on
    // axis 1 (split-K remnant?) -- confirm.
    %27 = arith.muli %1, %c32_i32 : i32
    %28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %30 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %31 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    // Row indices of A, wrapped with % M so out-of-range rows stay in bounds.
    %32 = arith.addi %19, %17 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %33 = arith.remsi %32, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %34 = tt.splat %arg6 : (i32) -> tensor<256x1xi32, #blocked0>
    %35 = arith.addi %30, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x32xi32, #blocked0>
    %37 = tt.broadcast %36 : (tensor<1x32xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
    %38 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<256x32x!tt.ptr<f16>, #blocked0>
    %39 = arith.addi %31, %29 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
    %41 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
    // Column indices of B, wrapped with % N (the unwrapped %42 is reused for the C store).
    %42 = arith.addi %23, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %43 = arith.remsi %42, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %44 = tt.expand_dims %43 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
    %45 = tt.broadcast %44 : (tensor<1x128xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
    %46 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #blocked1>
    %47 = arith.index_cast %arg5 : i32 to index
    // Per-chunk pointer advance for B: 32 rows * row stride.
    %48 = arith.muli %arg7, %c32_i32 : i32
    %49 = tt.splat %48 : (i32) -> tensor<32x128xi32, #blocked1>
    // A tile pointers: A + row * %arg6 + k.
    %50 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<256x1xi32, #blocked0>
    %51 = arith.muli %50, %34 : tensor<256x1xi32, #blocked0>
    %52 = tt.broadcast %51 : (tensor<256x1xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
    %53 = arith.addi %52, %37 : tensor<256x32xi32, #blocked0>
    %54 = tt.addptr %38, %53 : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
    // B tile pointers: B + k * %arg7 + col.
    %55 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1>
    %56 = tt.broadcast %55 : (tensor<32x1xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
    %57 = arith.addi %56, %45 : tensor<32x128xi32, #blocked1>
    %58 = tt.addptr %46, %57 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
    // Prologue:
    //   1. Load the first K chunk (masked off when K <= 0) into slot 0 of each
    //      two-slot buffer.
    //   2. Advance the load pointers to the next chunk.
    //   3. Convert the first 16-wide K half of slot 0 into dot operands.
    %59 = arith.cmpi slt, %c0, %47 : index
    %60 = triton_gpu.alloc_tensor : tensor<2x256x32xf16, #shared0>
    %64 = triton_gpu.alloc_tensor : tensor<2x32x128xf16, #shared1>
    %61 = tt.splat %59 : (i1) -> tensor<256x32xi1, #blocked0>
    %65 = tt.splat %59 : (i1) -> tensor<32x128xi1, #blocked1>
    %62 = tt.load %54, %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
    %66 = tt.load %58, %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
    %63 = tensor.insert_slice %62 into %60[%c0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
    %67 = tensor.insert_slice %66 into %64[%c0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
    %68 = tt.addptr %54, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
    %69 = tt.addptr %58, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
    %70 = tensor.extract_slice %63[0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
    %71 = tensor.extract_slice %67[0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
    %72 = tensor.extract_slice %70[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
    // Barrier between the slot-0 writes above and the first reads below.
    gpu.barrier
    %73 = triton_gpu.convert_layout %72 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
    %74 = tensor.extract_slice %71[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
    %75 = triton_gpu.convert_layout %74 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
    // Main loop over K. Iter args:
    //   %arg10       fp32 accumulator
    //   %arg11/12    pointers advanced each iteration but never loaded from
    //   %arg13/14    the two-slot A / B buffers
    //   %arg15/16    current chunk (slot view) of A / B
    //   %arg17/18    pointers of the chunk to prefetch
    //   %arg19       current k
    //   %arg20/21    insert / extract slot counters (used mod 2)
    //   %arg22/23    first K half of the current chunk, already in dot layouts
    %76:14 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %cst_0, %arg11 = %54, %arg12 = %58, %arg13 = %63, %arg14 = %67, %arg15 = %70, %arg16 = %71, %arg17 = %68, %arg18 = %69, %arg19 = %c0, %arg20 = %c1_i32, %arg21 = %c1_i32, %arg22 = %73, %arg23 = %75) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>) {
      // Prefetch the next chunk; the load is fully masked off on the last iteration.
      %104 = arith.addi %arg19, %c32 : index
      %105 = arith.cmpi slt, %104, %47 : index
      %106 = arith.remsi %arg20, %c2_i32 : i32
      %107 = arith.remsi %arg21, %c2_i32 : i32
      %108 = arith.index_cast %107 : i32 to index
      %200 = arith.index_cast %106 : i32 to index
      %109 = tt.splat %105 : (i1) -> tensor<256x32xi1, #blocked0>
      %112 = tt.splat %105 : (i1) -> tensor<32x128xi1, #blocked1>
      %110 = tt.load %arg17, %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
      %113 = tt.load %arg18, %112 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
      // Two K=16 dots over the current 32-wide chunk.
      %96 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
      %97 = tensor.extract_slice %arg15[0, 16] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
      %98 = triton_gpu.convert_layout %97 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
      %99 = tensor.extract_slice %arg16[16, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
      %100 = triton_gpu.convert_layout %99 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
      %101 = tt.dot %98, %100, %96 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
      %102 = tt.addptr %arg11, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
      %103 = tt.addptr %arg12, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
      // Store the prefetched chunk into slot %200, with a barrier on each side of the writes.
      gpu.barrier
      %111 = tensor.insert_slice %110 into %arg13[%200, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
      %114 = tensor.insert_slice %113 into %arg14[%200, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
      gpu.barrier
      %115 = tt.addptr %arg17, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
      %116 = tt.addptr %arg18, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
      // %arg20 and %arg21 both start at 1 and are incremented together, so the
      // slot read here (%108) is the slot just written (%200): it becomes the
      // next iteration's current chunk.
      %117 = tensor.extract_slice %111[%108, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
      %118 = tensor.extract_slice %114[%108, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
      %119 = arith.addi %arg20, %c1_i32 : i32
      %120 = arith.addi %arg21, %c1_i32 : i32
      // Pre-convert the first K half of the next chunk for the next iteration's first dot.
      %121 = tensor.extract_slice %117[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
      %122 = triton_gpu.convert_layout %121 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
      %123 = tensor.extract_slice %118[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
      %124 = triton_gpu.convert_layout %123 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
      scf.yield %101, %102, %103, %111, %114, %117, %118, %115, %116, %104, %119, %120, %122, %124 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
    }
    // Epilogue: move the accumulator to #blocked1 and truncate it to f16.
    // Store it with mask (row < M) && (col < N), using unwrapped row/col indices.
    // In these generic cmpi ops, predicate = 2 is `slt` in the arith CmpIPredicate numbering.
    gpu.barrier
    %77 = triton_gpu.convert_layout %76#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked1>
    %78 = arith.addi %20, %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %79 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked1>
    %80 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
    %81 = tt.broadcast %80 : (tensor<1x128xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
    %82 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<256x128x!tt.ptr<f16>, #blocked1>
    %83 = "triton_gpu.cmpi"(%78, %25) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %84 = "triton_gpu.cmpi"(%42, %26) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %85 = tt.expand_dims %84 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi1, #blocked1>
    %86 = tt.broadcast %85 : (tensor<1x128xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
    %87 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi32, #blocked1>
    %88 = arith.muli %87, %79 : tensor<256x1xi32, #blocked1>
    %89 = tt.broadcast %88 : (tensor<256x1xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
    %90 = arith.addi %89, %81 : tensor<256x128xi32, #blocked1>
    %91 = tt.addptr %82, %90 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1>
    %92 = arith.truncf %77 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
    %93 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi1, #blocked1>
    %94 = tt.broadcast %93 : (tensor<256x1xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
    %95 = arith.andi %94, %86 : tensor<256x128xi1, #blocked1>
    tt.store %91, %92, %95 : tensor<256x128xf16, #blocked1>
    return
  }
}