diff --git a/python/bwd.ttgir b/python/bwd.ttgir
index eee5b7435..c79d41631 100644
--- a/python/bwd.ttgir
+++ b/python/bwd.ttgir
@@ -140,13 +140,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
     %129 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
     %130 = tt.dot %128, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-    //%131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
-    //%132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
-    //%133 = triton_gpu.convert_layout %125 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-    //%134 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-    //%135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-    //%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
-    //tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>
+    %131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
+    %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
+    %133 = triton_gpu.convert_layout %125 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+    %134 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+    %135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+    %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
+    tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>
     %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2>
     %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1>
     %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1>
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 6eea40916..67a36fedc 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -326,8 +326,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     triton.testing.assert_almost_equal(ref_dv, tri_dv)
     triton.testing.assert_almost_equal(ref_dk, tri_dk)
     triton.testing.assert_almost_equal(ref_dq, tri_dq)
-    print(ref_dk, tri_dk)
-    print(ref_dq, tri_dq)
 
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
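
For context, the seven re-enabled ops in the bwd.ttgir hunk implement the dq read-modify-write of the attention backward pass: load the fp32 dq tile (%131), accumulate an f16 x f16 dot product into it (%133-%135), and store the result back (%136 and the tt.store). Below is a rough PyTorch-level sketch of that single tile update, assuming the usual flash-attention backward structure; the names dq, ds, and k and the random data are illustrative only and are not taken from the patch.

    import torch

    # Tile shapes come from the tensor types in the hunk:
    # ds is 128x128 f16, k is 128x64 f16, dq is a 128x64 f32 accumulator.
    M, N, D = 128, 128, 64

    dq = torch.zeros(M, D, dtype=torch.float32)  # %131: tt.load of the dq tile
    ds = torch.randn(M, N, dtype=torch.float16)  # %133: dot operand 0 (f16)
    k = torch.randn(N, D, dtype=torch.float16)   # %134: dot operand 1 (f16)

    # %135: tt.dot with the loaded dq tile as the fp32 accumulator.
    dq = dq + ds.float() @ k.float()

    # %136 / tt.store: the updated fp32 tile is written back through %arg29.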