[FRONTEND] add an attr for masked load without explicit other (#55)
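
This change threads a new `isOtherUnspecified` boolean attribute through `tt.load` (and `triton_gpu.copy_async`) so the backend can distinguish a user-supplied `other` value from a placeholder the frontend had to materialize. The motivating case is the common masked load with no `other` argument; a minimal usage sketch (the kernel name and parameters are illustrative, not part of this commit):

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # No `other` argument: the frontend now passes None down to the builder,
    # which inserts a zero placeholder and sets isOtherUnspecified = true on
    # the resulting tt.load (previously the frontend materialized an undef).
    x = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)
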
@@ -80,7 +80,7 @@ def TT_LoadOp : TT_Op<"load",
   let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
-                       BoolAttr:$isVolatile);
+                       BoolAttr:$isVolatile, BoolAttr:$isOtherUnspecified);

   let results = (outs TT_Type:$result);

@@ -89,7 +89,8 @@ def TT_LoadOp : TT_Op<"load",
     OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
     OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-               "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+               "triton::EvictionPolicy":$evict, "bool":$isVolatile,
+               "bool":$isOtherUnspecified)>
   ];

   let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";

@@ -48,7 +48,8 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
   let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
-                       BoolAttr:$isVolatile);
+                       BoolAttr:$isVolatile,
+                       BoolAttr:$isOtherUnspecified);

   let results = (outs TT_Type:$result);

@@ -230,7 +230,8 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
     Type retType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
-        adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
+        adaptor.cache(), adaptor.evict(), adaptor.isVolatile(),
+        adaptor.isOtherUnspecified());
     return success();
   }
 };

@@ -92,6 +92,8 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
       ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
   state.addAttribute(isVolatileAttrName(state.name),
                      builder.getBoolAttr(isVolatile));
+  state.addAttribute(isOtherUnspecifiedAttrName(state.name),
+                     builder.getBoolAttr(false));
   state.addTypes({resultType});
 }

@@ -40,8 +40,8 @@ def CombineGEPPattern : Pat<
 // select(cond, load(ptrs, broadcast(cond), ???), other)
 //   => load(ptrs, broadcast(cond), other)
 def CombineSelectMaskedLoadPattern : Pat<
-    (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
-    (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
+    (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile, $isOtherUnspecified), $falseValue),
+    (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile, $isOtherUnspecified)>;

 // broadcast(cst) => cst
 def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;

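For intuition, the select-over-masked-load shape this rewrite targets typically comes from frontend code like the hedged sketch below (identifiers are illustrative; whether the broadcast structure matches the pattern exactly depends on how the scalar condition is expanded):

import triton
import triton.language as tl

@triton.jit
def select_masked_load(ptr, out, cond, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # A load masked by a scalar condition that gets broadcast ...
    loaded = tl.load(ptr + offs, mask=cond, other=1.0)
    # ... then selected against the same condition: the pattern folds the
    # pair into one load whose `other` becomes the select's false value,
    # now carrying $isOtherUnspecified through unchanged.
    x = tl.where(cond, loaded, 0.0)
    tl.store(out + offs, x)
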
@@ -7,8 +7,8 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
 // convert_layout(load(...), #L) => copy_async(...); barrier
 // if #L is smem_layout
 def CopyAsyncOptPattern : Pat<
-    (TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile)),
-    (TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile),
+    (TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified)),
+    (TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified),
     [(Constraint<CPred<"isSharedLayout($0)">> $res)]>;

 // ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1)

@@ -226,7 +226,7 @@ void LoopPipeliner::emitPrologue() {
         newOp = builder.create<triton::gpu::CopyAsyncOp>(
             op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
             loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
-            loadOp.isVolatile());
+            loadOp.isVolatile(), loadOp.isOtherUnspecified());
       } else
         llvm_unreachable("This should be LoadOp");
     } else

@@ -380,7 +380,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
           nextMapping.lookupOrDefault(loadOp.ptr()),
           nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
-          loadOp.evict(), loadOp.isVolatile());
+          loadOp.evict(), loadOp.isVolatile(), loadOp.isOtherUnspecified());
     } else
       nextOp = builder.clone(*op, nextMapping);
     // llvm::errs() << "epilogue cloning...: " << *op << "\n";

@@ -1306,7 +1306,8 @@ void init_triton_ir(py::module &&m) {
          })
     .def("create_masked_load",
          [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
-            mlir::Value &other, mlir::triton::CacheModifier cacheModifier,
+            std::optional<mlir::Value> &other,
+            mlir::triton::CacheModifier cacheModifier,
             mlir::triton::EvictionPolicy evictionPolicy,
             bool isVolatile) -> mlir::Value {
            auto loc = self.getUnknownLoc();
@@ -1315,9 +1316,22 @@ void init_triton_ir(py::module &&m) {
            mlir::Type elementType = ptrType.getElementType()
                                         .dyn_cast<mlir::triton::PointerType>()
                                         .getPointeeType();
-           return self.create<mlir::triton::LoadOp>(
-               loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
-               mask, other, cacheModifier, evictionPolicy, isVolatile);
+           mlir::Type resultType =
+               mlir::RankedTensorType::get(shape, elementType);
+           if (other.has_value()) {
+             return self.create<mlir::triton::LoadOp>(
+                 loc, resultType, ptrs, mask, other.value(), cacheModifier,
+                 evictionPolicy, isVolatile, false);
+           } else {
+             mlir::Value dummy_other = self.create<mlir::arith::ConstantOp>(
+                 loc, resultType,
+                 mlir::DenseElementsAttr::get(resultType,
+                                              self.getZeroAttr(elementType)));
+             return self.create<mlir::triton::LoadOp>(
+                 loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
+                 mask, dummy_other, cacheModifier, evictionPolicy, isVolatile,
+                 true);
+           }
          })
     .def("create_masked_store",
          [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,

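With `std::optional<mlir::Value>` in the binding signature, pybind11 maps Python `None` onto an empty optional, so the frontend can express "no `other`" directly. A hedged sketch of the two call shapes (the `builder` object and the handles are assumed to come from the surrounding codegen context, as in the next hunk):

# other supplied: no placeholder needed, isOtherUnspecified = false on tt.load
x = builder.create_masked_load(ptrs, mask, other,
                               cache, eviction, is_volatile)
# other omitted: zero placeholder inserted, isOtherUnspecified = true
y = builder.create_masked_load(ptrs, mask, None,
                               cache, eviction, is_volatile)
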
@@ -706,23 +706,17 @@ def load(ptr: tl.tensor,
     else:
         dst_ty = elt_ty

-    if not mask and not other:
+    if not mask:
+        if other:
+            raise ValueError("`other` cannot be provided without `mask`")
         return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
                          dst_ty)
-    if not mask:
-        raise ValueError("`other` cannot be provided without `mask`")
-
-    if not other:
-        other_ir = ir.undef.get(elt_ty.to_ir(builder))
-        if ptr.type.is_block():
-            other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
-        other = tl.tensor(other_ir, dst_ty)
-
-    return tl.tensor(builder.create_masked_load(ptr.handle,
-                                                mask.handle,
-                                                other.handle,
-                                                cache, eviction, is_volatile),
-                     dst_ty)
+    else:
+        return tl.tensor(builder.create_masked_load(ptr.handle,
+                                                    mask.handle,
+                                                    other.handle if other else None,
+                                                    cache, eviction, is_volatile),
+                         dst_ty)


 def store(ptr: tl.tensor,

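The restructured Python path keeps the existing validation: `other` without `mask` is still rejected at kernel compile time, now up front in the `not mask` branch. An illustrative (hypothetical) kernel that triggers it:

import triton
import triton.language as tl

@triton.jit
def bad_load(ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Rejected with ValueError: "`other` cannot be provided without `mask`"
    x = tl.load(ptr + offs, other=0.0)
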
@@ -46,7 +46,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
   %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
-  %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
+  %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x128xf32>
   tt.store %19, %20, %cst, : tensor<128x128xf32>
   return
 }

@@ -20,10 +20,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
     // CHECK: offset = 0, size = 8192
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
     // CHECK: offset = 8192, size = 8192
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

@@ -47,17 +47,17 @@ func @synthesized_reusable(%A : !tt.ptr<f16>) {

   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
-  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
   // CHECK: offset = 0, size = 8192
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-  %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL>
   // CHECK: offset = 8192, size = 8192
   %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
-  %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
   // CHECK: offset = 16384, size = 8192
   %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
-  %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL>
   // CHECK: offset = 0, size = 8192
   %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

@@ -21,7 +21,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
   // CHECK: llvm.inline_asm
   // CHECK: llvm.inline_asm
-  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
+  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
   return
 }
 }

@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-SAME: ld.global.v4.b32
   // CHECK: llvm.inline_asm
   // CHECK-SAME: ld.global.v4.b32
-  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
+  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
   return
 }
 }

@@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-SAME: ld.global.v2.b32
   // CHECK: llvm.inline_asm
   // CHECK-SAME: ld.global.v2.b32
-  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
+  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf16, #blocked0>
   return
 }
 }

@@ -64,7 +64,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: masked_load_const_other
 func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
   %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
-  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
+  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
   return
 }
 }

@@ -45,16 +45,20 @@ func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
 }

 // CHECK-LABEL: @test_combine_select_masked_load_pattern
-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> tensor<8xf32> {
+func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
   %mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
   %false_val = arith.constant dense<0.0> : tensor<8xf32>

-  // CHECK: %[[res:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
-  %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+  // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
+  %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
   %0 = select %cond, %x, %false_val : tensor<8xf32>

-  // CHECK: return %[[res]] : tensor<8xf32>
-  return %0 : tensor<8xf32>
+  // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
+  %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
+  %1 = select %cond, %y, %false_val : tensor<8xf32>
+
+  // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
+  return %0, %1 : tensor<8xf32>, tensor<8xf32>
 }

 // CHECK-LABEL: @test_combine_broadcast_constant_pattern

@@ -24,10 +24,10 @@ module {
     %15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
       %cst_0 = arith.constant 0.000000e+00 : f32
      %18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
-      %19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
+      %19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
       %cst_1 = arith.constant 0.000000e+00 : f32
       %20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
-      %21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
+      %21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
       %22 = arith.addf %19, %21 : tensor<256xf32>
       %23 = arith.addf %arg7, %22 : tensor<256xf32>
       %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>

@@ -40,7 +40,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
   %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
   %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
   %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
-  %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
+  %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked1>
   tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
   return
 }

@@ -32,9 +32,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

     %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

@@ -73,9 +73,9 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

     %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

@@ -103,7 +103,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
   %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
   %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>

-  %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
   %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>

   %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>

@@ -113,7 +113,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

   scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
     %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
     %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>