dq now mma
@@ -72,24 +72,25 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value staIdx1 = i32_val(0);
   Value stride0 = dstStrides[outOrd[0]];
   Value stride1 = dstStrides[outOrd[1]];
-  if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-    if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-      unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
-      unsigned key = (rhsVal/outVec) % maxPhase;
-      llvm::outs() << srcDistributedLayout.dyn_cast<MmaEncodingAttr>() << " " << rhsVal << " " << key << "\n";
-      if(cache.find(key) == cache.end())
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+      unsigned rhsVal =
+          cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+      unsigned key = (rhsVal / outVec) % maxPhase;
+      if (cache.find(key) == cache.end())
         cache[key] = dynIdx0;
       dynIdx0 = cache[key];
-      staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase));
+      staIdx0 =
+          i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
     }
-  if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
-    if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
       dynIdx1 = addOp.getLhs();
       staIdx1 = addOp.getRhs();
     }
-
-

   // offset along non-contiguous dimension
   Value off1 = mul(dynIdx1, stride1);
   // swizzled offset along contiguous dimension
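The cache above is keyed on (rhsVal / outVec) % maxPhase: constant offsets that agree on this swizzle phase can share one swizzled dynamic index, with the period-aligned remainder re-added afterwards as staIdx0. That split is valid whenever shifting an index by a whole period (outVec * maxPhase) commutes with the swizzle. Below is a minimal standalone check of that property, assuming an illustrative XOR swizzle rather than Triton's exact formula; all names in the sketch are assumptions, not part of the patch.

#include <cassert>
#include <cstdint>

// Illustrative XOR-based swizzle (not Triton's exact one): the phase is
// derived from the vector index and XOR-ed back into it.
uint32_t swizzle(uint32_t idx, uint32_t outVec, uint32_t maxPhase) {
  uint32_t vec = idx / outVec;
  uint32_t phase = vec % maxPhase;
  return (vec ^ phase) * outVec + idx % outVec;
}

int main() {
  const uint32_t outVec = 4, maxPhase = 8;   // powers of two, as in Triton
  const uint32_t period = outVec * maxPhase; // one full swizzle period
  for (uint32_t idx = 0; idx < 1024; ++idx)
    // Shifting an index by a whole period commutes with the swizzle, so the
    // period-aligned part of a constant offset (staIdx0 above) can be added
    // after swizzling the shared dynamic part (the cached dynIdx0). The
    // commuting step needs maxPhase to be a power of two.
    assert(swizzle(idx + period, outVec, maxPhase) ==
           swizzle(idx, outVec, maxPhase) + period);
  return 0;
}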
@@ -100,10 +101,9 @@ void storeDistributedToShared(Value src, Value llSrc,
   remained = udiv(remained, minVecVal);
   off0 = add(off0, mul(remained, minVecVal));
   Value offset = add(off1, mul(off0, stride0));
+  Value staOffset = add(mul(staIdx1, stride1), mul(staIdx0, stride0));
   // add static offset
-  offset = add(offset, mul(staIdx1, stride1));
-  offset = add(offset, mul(staIdx0, stride0));
+  offset = add(offset, staOffset);

   // step 3: store
   Value smemAddr = gep(elemPtrTy, smemBase, offset);
-
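This hunk folds both static index terms into a single staOffset before the final add. Since staIdx0 and staIdx1 are constants, the whole term is a compile-time expression, leaving one dependent add on the dynamic offset instead of two. A plain-arithmetic sketch of the resulting address computation follows; the function and parameter names are assumed for illustration, not the Triton codegen API.

#include <cstdint>

// The dynamic part comes from swizzled per-thread indices; the static part
// is built purely from constants, so the backend can fold it into a single
// immediate addend on the final shared-memory address.
uint32_t sharedOffset(uint32_t off0, uint32_t off1,       // dynamic, swizzled
                      uint32_t staIdx0, uint32_t staIdx1, // constants
                      uint32_t stride0, uint32_t stride1) {
  uint32_t offset = off1 + off0 * stride0;                    // dynamic part
  uint32_t staOffset = staIdx1 * stride1 + staIdx0 * stride0; // constant-only
  return offset + staOffset; // one add on the dynamic value instead of two
}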
@@ -31,22 +31,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
     %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
-    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
     %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
     %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
-    %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
+    %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #mma1>
     %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-    %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+    %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
     %23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
     %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
-    %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
-    %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
+    %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>) -> tensor<1x64xi32, #mma1>
+    %26 = tt.broadcast %25 : (tensor<1x64xi32, #mma1>) -> tensor<128x64xi32, #mma1>
     %27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     %28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
     %29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     %30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     %31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
+    %32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #mma1>
     %33 = arith.muli %0, %arg23 : i32
     %34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
     %35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
@@ -57,7 +57,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
     %41 = arith.muli %arg14, %c128_i32 : i32
     %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
-    %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
+    %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #mma1>
     %44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     %45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     scf.for %arg25 = %c0 to %13 step %c1 {
@@ -65,11 +65,11 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
       %47 = arith.muli %46, %c128_i32 : i32
       %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
       %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
-      %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+      %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
       %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+      %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
       %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
-      %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
+      %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>) -> tensor<128x1xi32, #mma1>
       %55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
       %56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
       %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
@@ -88,13 +88,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
       %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
       %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
       %72 = tt.trans %71 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0>
-      %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
-      %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
-      %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
-      %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
+      %73 = arith.muli %54, %20 : tensor<128x1xi32, #mma1>
+      %74 = tt.broadcast %73 : (tensor<128x1xi32, #mma1>) -> tensor<128x64xi32, #mma1>
+      %75 = arith.addi %74, %26 : tensor<128x64xi32, #mma1>
+      %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #mma1>, tensor<128x64xi32, #mma1>
       %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
+      %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #mma1>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
         %86 = arith.index_cast %arg26 : index to i32
         %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
         %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
@@ -142,17 +142,15 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
         %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
         %129 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
         %130 = tt.dot %128, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-        %131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
-        %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
+        %132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #mma1>
         %133 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
         %134 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
         %135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-        %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
-        tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>
-        %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
+        tt.store %arg29, %135 : tensor<128x64xf32, #mma1>
+        %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #mma1>, tensor<128x64xi32, #mma1>
         %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
         %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-        scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
+        scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #mma1>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
       }
       %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       %81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>