diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index ce59cb214..0960be171 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4799,6 +4799,8 @@ public: MembarAnalysis membarPass(&allocation); membarPass.run(); + mod.print(llvm::errs()); + RewritePatternSet scf_patterns(context); mlir::populateLoopToStdConversionPatterns(scf_patterns); mlir::ConversionTarget scf_target(*context); diff --git a/python/tests/matmul.ttgir b/python/tests/matmul.ttgir new file mode 100644 index 000000000..23117bb84 --- /dev/null +++ b/python/tests/matmul.ttgir @@ -0,0 +1,156 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 2]}> +#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 8 : i32} { + func public @_kernel_0d1d2d3d4d5d6d7d8d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : index + %cst = arith.constant dense<32> : tensor<256x32xi32, #blocked0> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c8_i32 = arith.constant 8 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c32_i32 = arith.constant 32 : i32 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = tt.get_program_id {axis = 1 : i32} : i32 + %2 = arith.addi %arg3, %c255_i32 : i32 + %3 = arith.divsi %2, %c256_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c256_i32 : i32 + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %20 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = arith.muli %15, %c128_i32 : i32 + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %23 = tt.splat %21 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %24 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %25 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %26 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %27 = arith.muli %1, %c32_i32 : i32 + %28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %30 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %31 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = arith.addi %19, %17 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %33 = arith.remsi %32, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %34 = tt.splat %arg6 : (i32) -> tensor<256x1xi32, #blocked0> + %35 = arith.addi %30, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x32xi32, #blocked0> + %37 = tt.broadcast %36 : (tensor<1x32xi32, #blocked0>) -> tensor<256x32xi32, #blocked0> + %38 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x32x!tt.ptr, #blocked0> + %39 = arith.addi %31, %29 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1> + %41 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1> + %42 = arith.addi %23, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %43 = arith.remsi %42, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.expand_dims %43 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1> + %45 = tt.broadcast %44 : (tensor<1x128xi32, #blocked1>) -> tensor<32x128xi32, #blocked1> + %46 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked1> + %47 = arith.index_cast %arg5 : i32 to index + %48 = arith.muli %arg7, %c32_i32 : i32 + %49 = tt.splat %48 : (i32) -> tensor<32x128xi32, #blocked1> + %50 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<256x1xi32, #blocked0> + %51 = arith.muli %50, %34 : tensor<256x1xi32, #blocked0> + %52 = tt.broadcast %51 : (tensor<256x1xi32, #blocked0>) -> tensor<256x32xi32, #blocked0> + %53 = arith.addi %52, %37 : tensor<256x32xi32, #blocked0> + %54 = tt.addptr %38, %53 : tensor<256x32x!tt.ptr, #blocked0>, tensor<256x32xi32, #blocked0> + %55 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1> + %56 = tt.broadcast %55 : (tensor<32x1xi32, #blocked1>) -> tensor<32x128xi32, #blocked1> + %57 = arith.addi %56, %45 : tensor<32x128xi32, #blocked1> + %58 = tt.addptr %46, %57 : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + %59 = arith.cmpi slt, %c0, %47 : index + %60 = triton_gpu.alloc_tensor : tensor<2x256x32xf16, #shared0> + %64 = triton_gpu.alloc_tensor : tensor<2x32x128xf16, #shared1> + %61 = tt.splat %59 : (i1) -> tensor<256x32xi1, #blocked0> + %65 = tt.splat %59 : (i1) -> tensor<32x128xi1, #blocked1> + %62 = tt.load %54, %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0> + %66 = tt.load %58, %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1> + %63 = tensor.insert_slice %62 into %60[%c0_i32, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0> + %67 = tensor.insert_slice %66 into %64[%c0_i32, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1> + %68 = tt.addptr %54, %cst : tensor<256x32x!tt.ptr, #blocked0>, tensor<256x32xi32, #blocked0> + %69 = tt.addptr %58, %49 : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + %70 = tensor.extract_slice %63[0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0> + %71 = tensor.extract_slice %67[0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1> + %72 = tensor.extract_slice %70[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0> + gpu.barrier + %73 = triton_gpu.convert_layout %72 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> + %74 = tensor.extract_slice %71[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1> + %75 = triton_gpu.convert_layout %74 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> + %76:14 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %cst_0, %arg11 = %54, %arg12 = %58, %arg13 = %63, %arg14 = %67, %arg15 = %70, %arg16 = %71, %arg17 = %68, %arg18 = %69, %arg19 = %c0, %arg20 = %c1_i32, %arg21 = %c1_i32, %arg22 = %73, %arg23 = %75) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked0>, tensor<32x128x!tt.ptr, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr, #blocked0>, tensor<32x128x!tt.ptr, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>) { + %104 = arith.addi %arg19, %c32 : index + %105 = arith.cmpi slt, %104, %47 : index + %106 = arith.remsi %arg20, %c2_i32 : i32 + %107 = arith.remsi %arg21, %c2_i32 : i32 + %108 = arith.index_cast %107 : i32 to index + %200 = arith.index_cast %106 : i32 to index + %109 = tt.splat %105 : (i1) -> tensor<256x32xi1, #blocked0> + %112 = tt.splat %105 : (i1) -> tensor<32x128xi1, #blocked1> + %110 = tt.load %arg17, %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0> + %113 = tt.load %arg18, %112 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1> + %96 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma> + %97 = tensor.extract_slice %arg15[0, 16] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0> + %98 = triton_gpu.convert_layout %97 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> + %99 = tensor.extract_slice %arg16[16, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1> + %100 = triton_gpu.convert_layout %99 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> + %101 = tt.dot %98, %100, %96 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma> + %102 = tt.addptr %arg11, %cst : tensor<256x32x!tt.ptr, #blocked0>, tensor<256x32xi32, #blocked0> + %103 = tt.addptr %arg12, %49 : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + gpu.barrier + %111 = tensor.insert_slice %110 into %arg13[%200, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0> + %114 = tensor.insert_slice %113 into %arg14[%200, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1> + gpu.barrier + %115 = tt.addptr %arg17, %cst : tensor<256x32x!tt.ptr, #blocked0>, tensor<256x32xi32, #blocked0> + %116 = tt.addptr %arg18, %49 : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + %117 = tensor.extract_slice %111[%108, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0> + %118 = tensor.extract_slice %114[%108, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1> + %119 = arith.addi %arg20, %c1_i32 : i32 + %120 = arith.addi %arg21, %c1_i32 : i32 + %121 = tensor.extract_slice %117[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0> + %122 = triton_gpu.convert_layout %121 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> + %123 = tensor.extract_slice %118[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1> + %124 = triton_gpu.convert_layout %123 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> + scf.yield %101, %102, %103, %111, %114, %117, %118, %115, %116, %104, %119, %120, %122, %124 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked0>, tensor<32x128x!tt.ptr, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr, #blocked0>, tensor<32x128x!tt.ptr, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> + } + gpu.barrier + %77 = triton_gpu.convert_layout %76#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked1> + %78 = arith.addi %20, %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %79 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked1> + %80 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1> + %81 = tt.broadcast %80 : (tensor<1x128xi32, #blocked1>) -> tensor<256x128xi32, #blocked1> + %82 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x128x!tt.ptr, #blocked1> + %83 = "triton_gpu.cmpi"(%78, %25) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %84 = "triton_gpu.cmpi"(%42, %26) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %85 = tt.expand_dims %84 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi1, #blocked1> + %86 = tt.broadcast %85 : (tensor<1x128xi1, #blocked1>) -> tensor<256x128xi1, #blocked1> + %87 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi32, #blocked1> + %88 = arith.muli %87, %79 : tensor<256x1xi32, #blocked1> + %89 = tt.broadcast %88 : (tensor<256x1xi32, #blocked1>) -> tensor<256x128xi32, #blocked1> + %90 = arith.addi %89, %81 : tensor<256x128xi32, #blocked1> + %91 = tt.addptr %82, %90 : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %92 = arith.truncf %77 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %93 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi1, #blocked1> + %94 = tt.broadcast %93 : (tensor<256x1xi1, #blocked1>) -> tensor<256x128xi1, #blocked1> + %95 = arith.andi %94, %86 : tensor<256x128xi1, #blocked1> + tt.store %91, %92, %95 : tensor<256x128xf16, #blocked1> + return + } +} diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 91ff732a7..560c73423 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1431,9 +1431,7 @@ def compile(fn, **kwargs): import re match = re.search(prototype_pattern[ir], src, re.MULTILINE) name, signature = match.group(1), match.group(2) - print(name, signature) types = re.findall(arg_type_pattern[ir], signature) - print(types) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 0ffcc1677..406b49401 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -27,28 +27,29 @@ def get_configs_io_bound(): @triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), + #configs=[ + # # basic configs for compute-bound matmuls + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # # good for int8 + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + #] + get_configs_io_bound(), + configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=2, num_warps=8)], key=['M', 'N', 'K'], prune_configs_by={ 'early_config_prune': early_config_prune, @@ -113,7 +114,7 @@ def _kernel(A, B, C, M, N, K, class _matmul(torch.autograd.Function): - kernel = _kernel + kernel = None _locks = dict() @@ -134,12 +135,17 @@ class _matmul(torch.autograd.Function): # accumulator types ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + #grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + if _matmul.kernel is None: + _matmul.kernel = triton.compile("/root/code/triton-mlir/python/tests/matmul.ttgir", num_stages=2, num_warps=8) + #_matmul.kernel = _kernel + _matmul.kernel[(8192//256 * 8192//128, 1, 1,)](a.data_ptr(), b.data_ptr(), c.data_ptr(), + M, N, K, + a.stride(0), b.stride(0), c.stride(0)) + #_matmul.kernel[grid](a, b, c, + # M, N, K, + # a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + # GROUP_M=8, ACC_TYPE=ACC_TYPE) return c @staticmethod diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 522964d56..7ec3f458f 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -74,6 +74,8 @@ class Autotuner(KernelInterface): for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start + for config, ttime in timings.items(): + print(f"config: {config}, time: {ttime}") self.cache[key] = builtins.min(timings, key=timings.get) self.hook(args) self.configs_timings = timings diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index f11c3bc09..82a264dc2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -156,7 +156,7 @@ import triton.language as tl @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), ], key=['M', 'N', 'K'], )