more progress on the definition of layouts

Da Yan
2022-05-31 11:43:21 +00:00
parent 41d338d848
commit e36a54eb86
4 changed files with 203 additions and 36 deletions


@@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // Now we assume:
   //    contiguous = 1, order = 0, 1, 2, ...,
   llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
+  llvm::SmallVector<unsigned> warpTileSize(rank, 1);
   llvm::SmallVector<unsigned> blockTileSize(rank);
   llvm::SmallVector<unsigned> order(rank);
   int remainingThreads = numThreads;
@@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     remainingThreads /= blockTileSize[dim];
     // TODO: will we need repetition?
   }
-  Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
-      context, threadTileSize, blockTileSize, order);
+  Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get(
+      context, threadTileSize, warpTileSize, blockTileSize, order);
   return RankedTensorType::get(shape, elementType, encoding);
 });


@@ -50,7 +50,7 @@ private:
     if (!encoding)
       return dotOp.emitError() << name << " should have encoding";
     if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
-        !encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
+        !encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>())
       return dotOp.emitError() << name << " should be of distributed layout";
     if (name == 'c')
       cLayout = encoding;
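The verifier change mirrors the rename: a dot operand's encoding is accepted if it is either an MMA layout or the new sharded layout, both of which count as "distributed". Since the real check depends on MLIR attributes and LLVM's isa<>, here is a self-contained analogue of the same accept-one-of-several-kinds pattern using plain C++ RTTI; the class names are stand-ins for the dialect's attribute types, not the actual API.

#include <iostream>

// Stand-ins for the TritonGPU encoding attributes (hypothetical types
// mirroring only the class hierarchy, not the MLIR implementation).
struct Encoding { virtual ~Encoding() = default; };
struct MmaEncoding final : Encoding {};
struct ShardedEncoding final : Encoding {};
struct SharedEncoding final : Encoding {};

// Analogue of the verifier: dot operands must carry a distributed layout,
// i.e. either an MMA or a sharded encoding; anything else is rejected.
bool isDistributedLayout(const Encoding &enc) {
  return dynamic_cast<const MmaEncoding *>(&enc) != nullptr ||
         dynamic_cast<const ShardedEncoding *>(&enc) != nullptr;
}

int main() {
  ShardedEncoding sharded;
  SharedEncoding shared; // e.g. a shared-memory layout would not qualify
  std::cout << isDistributedLayout(sharded) << '\n'; // 1: accepted
  std::cout << isDistributedLayout(shared) << '\n';  // 0: rejected
  return 0;
}

In the MLIR version the early return on a missing encoding and the emitError diagnostics replace the boolean result, but the shape of the check is the same.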