diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 9cce8bab6..1a601bdd0 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -72,23 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc, Value staIdx1 = i32_val(0); Value stride0 = dstStrides[outOrd[0]]; Value stride1 = dstStrides[outOrd[1]]; - if(auto addOp = dyn_cast(dynIdx0.getDefiningOp())) - if(auto cstRhs = dyn_cast(addOp.getRhs().getDefiningOp())) { - unsigned rhsVal = cstRhs.getValue().cast().getValue().getSExtValue(); - unsigned key = (rhsVal/outVec) % maxPhase; - llvm::outs() << srcDistributedLayout.dyn_cast() << " " << rhsVal << " " << key << "\n"; - if(cache.find(key) == cache.end()) - cache[key] = dynIdx0; - dynIdx0 = cache[key]; - staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase)); - } - if(auto addOp = dyn_cast(dynIdx1.getDefiningOp())) - if(auto cstRhs = dyn_cast(addOp.getRhs().getDefiningOp())) { - dynIdx1 = addOp.getLhs(); - staIdx1 = addOp.getRhs(); - } - - + if (auto addOp = dyn_cast(dynIdx0.getDefiningOp())) + if (auto cstRhs = + dyn_cast(addOp.getRhs().getDefiningOp())) { + unsigned rhsVal = + cstRhs.getValue().cast().getValue().getSExtValue(); + unsigned key = (rhsVal / outVec) % maxPhase; + if (cache.find(key) == cache.end()) + cache[key] = dynIdx0; + dynIdx0 = cache[key]; + staIdx0 = + i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase)); + } + if (auto addOp = dyn_cast(dynIdx1.getDefiningOp())) + if (auto cstRhs = + dyn_cast(addOp.getRhs().getDefiningOp())) { + dynIdx1 = addOp.getLhs(); + staIdx1 = addOp.getRhs(); + } // offset along non-contiguous dimension Value off1 = mul(dynIdx1, stride1); @@ -100,10 +101,9 @@ void storeDistributedToShared(Value src, Value llSrc, remained = udiv(remained, minVecVal); off0 = add(off0, mul(remained, minVecVal)); Value offset = add(off1, mul(off0, stride0)); - + Value staOffset = add(mul(staIdx1, stride1), mul(staIdx0, stride0)); // add static offset - offset = add(offset, mul(staIdx1, stride1)); - offset = add(offset, mul(staIdx0, stride0)); + offset = add(offset, staOffset); // step 3: store Value smemAddr = gep(elemPtrTy, smemBase, offset); diff --git a/python/bwd.ttgir b/python/bwd.ttgir index d3f9aeccc..1a8fd9d6e 100644 --- a/python/bwd.ttgir +++ b/python/bwd.ttgir @@ -31,22 +31,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1> - %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2> + %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #mma1> %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> %23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> - %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> + %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>) -> tensor<1x64xi32, #mma1> + %26 = tt.broadcast %25 : (tensor<1x64xi32, #mma1>) -> tensor<128x64xi32, #mma1> %27 = tt.splat %6 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> %28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1> %29 = tt.splat %7 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> %30 = tt.splat %8 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> %31 = tt.splat %9 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %32 = tt.splat %10 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked2> + %32 = tt.splat %10 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #mma1> %33 = arith.muli %0, %arg23 : i32 %34 = tt.addptr %arg11, %33 : !tt.ptr, i32 %35 = tt.addptr %arg10, %33 : !tt.ptr, i32 @@ -57,7 +57,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %40 = tt.splat %34 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0> %41 = arith.muli %arg14, %c128_i32 : i32 %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1> - %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2> + %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #mma1> %44 = tt.splat %12 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> %45 = tt.splat %11 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> scf.for %arg25 = %c0 to %13 step %c1 { @@ -65,11 +65,11 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %47 = arith.muli %46, %c128_i32 : i32 %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> - %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> - %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>) -> tensor<128x1xi32, #mma1> %55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1> %56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1> @@ -88,13 +88,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> %72 = tt.trans %71 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0> - %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> - %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> - %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> - %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %73 = arith.muli %54, %20 : tensor<128x1xi32, #mma1> + %74 = tt.broadcast %73 : (tensor<128x1xi32, #mma1>) -> tensor<128x64xi32, #mma1> + %75 = arith.addi %74, %26 : tensor<128x64xi32, #mma1> + %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr, #mma1>, tensor<128x64xi32, #mma1> %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { + %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { %86 = arith.index_cast %arg26 : index to i32 %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0> %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> @@ -142,17 +142,15 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %129 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %130 = tt.dot %128, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> - %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> + %132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #mma1> %133 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %134 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> - tt.store %arg29, %136 : tensor<128x64xf32, #blocked2> - %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + tt.store %arg29, %135 : tensor<128x64xf32, #mma1> + %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #mma1>, tensor<128x64xi32, #mma1> %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> } %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>