From a3d0812d27d90fe59d02b553b94dd027b97c0fd2 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Fri, 8 Apr 2022 19:38:25 +0800 Subject: [PATCH] Update example --- rewrite-test/jit/matmul/matmul.mlir | 90 ++++++++++++++--------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir index 81b8a1bcf..87a81abb7 100644 --- a/rewrite-test/jit/matmul/matmul.mlir +++ b/rewrite-test/jit/matmul/matmul.mlir @@ -30,40 +30,40 @@ module { %c64_i32_7 = arith.constant 64 : i32 %17 = arith.muli %14, %c64_i32_7 : i32 %18 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %19 = "triton.broadcast"(%17) : (i32) -> tensor<64xi32> + %19 = triton.broadcast %17 : (i32) -> tensor<64xi32> %20 = arith.addi %19, %18 : tensor<64xi32> %c64_i32_8 = arith.constant 64 : i32 %21 = arith.muli %16, %c64_i32_8 : i32 %22 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32> + %23 = triton.broadcast %21 : (i32) -> tensor<64xi32> %24 = arith.addi %23, %22 : tensor<64xi32> %25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %26 = "triton.reshape"(%20) : (tensor<64xi32>) -> tensor<64x1xi32> - %27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32> + %26 = triton.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32> + %27 = triton.broadcast %arg6 : (i32) -> tensor<64x1xi32> %28 = arith.muli %26, %27 : tensor<64x1xi32> - %29 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<1x32xi32> + %29 = triton.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32> %c1_i32_9 = arith.constant 1 : i32 - %30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32> + %30 = triton.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32> %31 = arith.muli %29, %30 : tensor<1x32xi32> - %32 = "triton.broadcast"(%28) : (tensor<64x1xi32>) -> tensor<64x32xi32> - %33 = "triton.broadcast"(%31) : (tensor<1x32xi32>) -> tensor<64x32xi32> + %32 = triton.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32> + %33 = triton.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32> %34 = arith.addi %32, %33 : tensor<64x32xi32> - %35 = "triton.broadcast"(%arg0) : (!triton.ptr) -> tensor<64x32x!triton.ptr> - %36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr> - %37 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<32x1xi32> - %38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32> + %35 = triton.broadcast %arg0 : (!triton.ptr) -> tensor<64x32x!triton.ptr> + %36 = triton.getelementptr %35, %34, : tensor<64x32x!triton.ptr> + %37 = triton.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32> + %38 = triton.broadcast %arg7 : (i32) -> tensor<32x1xi32> %39 = arith.muli %37, %38 : tensor<32x1xi32> - %40 = "triton.reshape"(%24) : (tensor<64xi32>) -> tensor<1x64xi32> + %40 = triton.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32> %c1_i32_10 = arith.constant 1 : i32 - %41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32> + %41 = triton.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32> %42 = arith.muli %40, %41 : tensor<1x64xi32> - %43 = "triton.broadcast"(%39) : (tensor<32x1xi32>) -> tensor<32x64xi32> - %44 = "triton.broadcast"(%42) : (tensor<1x64xi32>) -> tensor<32x64xi32> + %43 = triton.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32> + %44 = triton.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32> %45 = arith.addi %43, %44 : tensor<32x64xi32> - %46 = "triton.broadcast"(%arg1) : (!triton.ptr) -> tensor<32x64x!triton.ptr> - %47 = "triton.getelementptr"(%46, %45) : (tensor<32x64x!triton.ptr>, tensor<32x64xi32>) -> tensor<32x64x!triton.ptr> + %46 = triton.broadcast %arg1 : (!triton.ptr) -> tensor<32x64x!triton.ptr> + %47 = triton.getelementptr %46, %45, : tensor<32x64x!triton.ptr> %cst = arith.constant 0.000000e+00 : f32 - %48 = "triton.broadcast"(%cst) : (f32) -> tensor<64x64xf32> + %48 = triton.broadcast %cst : (f32) -> tensor<64x64xf32> %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 %49 = arith.index_cast %c0_i32 : i32 to index @@ -72,56 +72,56 @@ module { %52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!triton.ptr>, tensor<32x64x!triton.ptr>) { %cst_14 = arith.constant dense : tensor<64x32xi1> %cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> - %82 = "triton.load"(%arg11, %cst_14, %cst_15) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<64x32x!triton.ptr>, tensor<64x32xi1>, tensor<64x32xf16>) -> tensor<64x32xf16> + %82 = triton.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> %cst_16 = arith.constant dense : tensor<32x64xi1> %cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> - %83 = "triton.load"(%arg12, %cst_16, %cst_17) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<32x64x!triton.ptr>, tensor<32x64xi1>, tensor<32x64xf16>) -> tensor<32x64xf16> + %83 = triton.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> %cst_18 = arith.constant 0.000000e+00 : f32 - %84 = "triton.broadcast"(%cst_18) : (f32) -> tensor<64x64xf32> - %85 = "triton.dot"(%82, %83, %84) {allowTF32 = true} : (tensor<64x32xf16>, tensor<32x64xf16>, tensor<64x64xf32>) -> tensor<64x64xf32> + %84 = triton.broadcast %cst_18 : (f32) -> tensor<64x64xf32> + %85 = triton.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> %86 = arith.addf %arg10, %85 : tensor<64x64xf32> %c32_i32_19 = arith.constant 32 : i32 - %87 = "triton.broadcast"(%c32_i32_19) : (i32) -> tensor<64x32xi32> - %88 = "triton.getelementptr"(%arg11, %87) : (tensor<64x32x!triton.ptr>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr> + %87 = triton.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32> + %88 = triton.getelementptr %arg11, %87, : tensor<64x32x!triton.ptr> %c32_i32_20 = arith.constant 32 : i32 %89 = arith.muli %arg7, %c32_i32_20 : i32 - %90 = "triton.broadcast"(%89) : (i32) -> tensor<32x64xi32> - %91 = "triton.getelementptr"(%arg12, %90) : (tensor<32x64x!triton.ptr>, tensor<32x64xi32>) -> tensor<32x64x!triton.ptr> + %90 = triton.broadcast %89 : (i32) -> tensor<32x64xi32> + %91 = triton.getelementptr %arg12, %90, : tensor<32x64x!triton.ptr> scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!triton.ptr>, tensor<32x64x!triton.ptr> } %53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16> %c64_i32_11 = arith.constant 64 : i32 %54 = arith.muli %14, %c64_i32_11 : i32 %55 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %56 = "triton.broadcast"(%54) : (i32) -> tensor<64xi32> + %56 = triton.broadcast %54 : (i32) -> tensor<64xi32> %57 = arith.addi %56, %55 : tensor<64xi32> %c64_i32_12 = arith.constant 64 : i32 %58 = arith.muli %16, %c64_i32_12 : i32 %59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32> + %60 = triton.broadcast %58 : (i32) -> tensor<64xi32> %61 = arith.addi %60, %59 : tensor<64xi32> - %62 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32> - %63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32> + %62 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32> + %63 = triton.broadcast %arg8 : (i32) -> tensor<64x1xi32> %64 = arith.muli %63, %62 : tensor<64x1xi32> - %65 = "triton.broadcast"(%arg2) : (!triton.ptr) -> tensor<64x1x!triton.ptr> - %66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr> - %67 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32> + %65 = triton.broadcast %arg2 : (!triton.ptr) -> tensor<64x1x!triton.ptr> + %66 = triton.getelementptr %65, %64, : tensor<64x1x!triton.ptr> + %67 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32> %c1_i32_13 = arith.constant 1 : i32 - %68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32> + %68 = triton.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32> %69 = arith.muli %67, %68 : tensor<1x64xi32> - %70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr>) -> tensor<64x64x!triton.ptr> - %71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32> - %72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr> - %73 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32> - %74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32> + %70 = triton.broadcast %66 : (tensor<64x1x!triton.ptr>) -> tensor<64x64x!triton.ptr> + %71 = triton.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %72 = triton.getelementptr %70, %71, : tensor<64x64x!triton.ptr> + %73 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32> + %74 = triton.broadcast %arg3 : (i32) -> tensor<64x1xi32> %75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32> - %76 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32> - %77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32> + %76 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32> + %77 = triton.broadcast %arg4 : (i32) -> tensor<1x64xi32> %78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32> - %79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1> - %80 = "triton.broadcast"(%78) : (tensor<1x64xi1>) -> tensor<64x64xi1> + %79 = triton.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1> + %80 = triton.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1> %81 = arith.andi %79, %80 : tensor<64x64xi1> - "triton.store"(%72, %53, %81) : (tensor<64x64x!triton.ptr>, tensor<64x64xf16>, tensor<64x64xi1>) -> () + triton.store %72, %53, %81, : tensor<64x64xf16> return } }