rename sharded_layout => blocked_layout

This commit is contained in:
Yan Da
2022-06-05 16:14:59 +08:00
parent bbf75b492f
commit 7807f64ef3
7 changed files with 98 additions and 96 deletions

View File

@@ -48,8 +48,8 @@ And the associated TritonGPU MLIR
); );
} }
def TritonGPUShardedEncodingAttr : TritonGPU_Attr<"TritonGPUShardedEncoding"> { def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
let mnemonic = "sharded_layout"; let mnemonic = "blocked_layout";
let description = [{ let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
@@ -74,7 +74,7 @@ size } ....
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63] A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
And the associated TritonGPU MLIR And the associated TritonGPU MLIR
#LAYOUT = #triton_gpu.sharded_layout<{ #LAYOUT = #triton_gpu.blocked_layout<{
threadTileSize = {2, 2} threadTileSize = {2, 2}
blockTileSize = {32, 8} blockTileSize = {32, 8}
}> }>

View File

@@ -37,7 +37,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
Attribute Attribute
TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) { TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed()) if (parser.parseLess().failed())
return {}; return {};
// Parse the data as a dictionary // Parse the data as a dictionary
@@ -94,14 +94,14 @@ TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
} }
} }
return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(), return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
threadTileSize, threadTileSize,
warpTileSize, warpTileSize,
blockTileSize, blockTileSize,
order); order);
} }
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const { void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{" printer << "<{"
<< "threadTileSize = [" << getThreadTileSize() << "]" << "threadTileSize = [" << getThreadTileSize() << "]"
<< ", warpTileSize = [" << getWarpTileSize() << "]" << ", warpTileSize = [" << getWarpTileSize() << "]"

View File

@@ -46,7 +46,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
remainingThreads /= blockTileSize[dim]; remainingThreads /= blockTileSize[dim];
// TODO: will we need repetition? // TODO: will we need repetition?
} }
Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get( Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order); context, threadTileSize, warpTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding); return RankedTensorType::get(shape, elementType, encoding);
}); });

View File

@@ -50,7 +50,7 @@ private:
if (!encoding) if (!encoding)
return dotOp.emitError() << name << " should have encoding"; return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() && if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>()) !encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout"; return dotOp.emitError() << name << " should be of distributed layout";
if (name == 'c') if (name == 'c')
cLayout = encoding; cLayout = encoding;

View File

@@ -40,89 +40,89 @@ module {
return return
} }
} }
module { // module {
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { // func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%c64 = arith.constant 64 : index // %c64 = arith.constant 64 : index
%c32 = arith.constant 32 : index // %c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index // %c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32 // %cst = arith.constant 0.000000e+00 : f32
%c256_i32 = arith.constant 256 : i32 // %c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32 // %0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32 // %1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%12 = arith.index_cast %arg4 : i32 to index // %12 = arith.index_cast %arg4 : i32 to index
%13 = arith.cmpi slt, %c0, %12 : index // %13 = arith.cmpi slt, %c0, %12 : index
%14 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %14 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%15 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %15 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%16 = arith.andi %6, %15 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %16 = arith.andi %6, %15 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%17 = triton_gpu.copy_async %8, %16, %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %17 = triton_gpu.copy_async %8, %16, %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%18 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %18 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%19 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %19 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%23 = tt.getelementptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %23 = tt.getelementptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%25 = tt.getelementptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %25 = tt.getelementptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%26 = arith.cmpi slt, %c32, %12 : index // %26 = arith.cmpi slt, %c32, %12 : index
%27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%29 = arith.andi %6, %28 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %29 = arith.andi %6, %28 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%30 = triton_gpu.copy_async %23, %29, %27 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %30 = triton_gpu.copy_async %23, %29, %27 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%31 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %31 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%32 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %32 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%36 = tt.getelementptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %36 = tt.getelementptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%38 = tt.getelementptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %38 = tt.getelementptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%39 = arith.cmpi slt, %c64, %12 : index // %39 = arith.cmpi slt, %c64, %12 : index
%40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%42 = arith.andi %6, %41 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %42 = arith.andi %6, %41 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%43 = triton_gpu.copy_async %36, %42, %40 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %43 = triton_gpu.copy_async %36, %42, %40 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%44 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %44 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%45 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %45 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%49 = tt.getelementptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %49 = tt.getelementptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%51 = tt.getelementptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %51 = tt.getelementptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index) { // %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index) {
%55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%58 = tt.getelementptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %58 = tt.getelementptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%60 = tt.getelementptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %60 = tt.getelementptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%61 = arith.addi %arg18, %c32 : index // %61 = arith.addi %arg18, %c32 : index
%62 = arith.cmpi slt, %61, %12 : index // %62 = arith.cmpi slt, %61, %12 : index
%63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%64 = tt.broadcast %62 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %64 = tt.broadcast %62 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%65 = arith.andi %64, %6 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %65 = arith.andi %64, %6 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%66 = triton_gpu.copy_async %arg17, %65, %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %66 = triton_gpu.copy_async %arg17, %65, %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%70 = tt.getelementptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %70 = tt.getelementptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%72 = tt.getelementptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %72 = tt.getelementptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index // scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index
} // }
%53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> // tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
return // return
} // }
} // }

View File

@@ -1,3 +1,5 @@
// RUN: triton-opt %s -tritongpu-verifier -verify-diagnostics
module { module {
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32

View File

@@ -1,13 +1,13 @@
// RUN: triton-opt %s -split-input-file -verify-diagnostics // RUN: triton-opt %s -split-input-file -verify-diagnostics
#reg = #triton_gpu.sharded_layout<{ #reg = #triton_gpu.blocked_layout<{
threadTileSize = [1, 1], threadTileSize = [1, 1],
warpTileSize = [32, 1], warpTileSize = [32, 1],
blockTileSize = [64, 1], blockTileSize = [64, 1],
order = [0] order = [0]
}> }>
#reg2 = #triton_gpu.sharded_layout<{ #reg2 = #triton_gpu.blocked_layout<{
threadTileSize = [2, 1], threadTileSize = [2, 1],
warpTileSize = [64, 1], warpTileSize = [64, 1],
blockTileSize = [128, 1], blockTileSize = [128, 1],