[CI] run clang-format (#24)

This commit is contained in:
Philippe Tillet
2022-07-26 17:25:03 -07:00
committed by GitHub
parent 25357083e6
commit 6d62d88d4f
62 changed files with 13673 additions and 11367 deletions

View File

@@ -1,7 +1,7 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <algorithm>
using namespace mlir;
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
//
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numThreads)
: context(context), numThreads(numThreads) {
// TODO: how does MLIR pick the right conversion?
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// or assert no encoding?
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32;
int remainingLanes = /*warp size*/ 32;
for (int64_t dim = 0; dim < rank; ++dim) {
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
context, threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding);
});
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
StandardOpsDialect, scf::SCFDialect>(
[&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding &&
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// return true;
return false;
});
}
// %dst = tt.broadcast %src
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) {
broadcasts.push_back(op);
});
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
if (auto blockedEnc =
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(),
blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(),
blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(),
broadcastAxis
);
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(), broadcastAxis);
newSrcType = RankedTensorType::get(
tensorType.getShape(),
tensorType.getElementType(),
newSrcEnc
);
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(
src.getLoc(), newSrcType, src
);
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
newSrcType, src);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(),
originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(),
originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(),
broadcastAxis
);
auto newType = RankedTensorType::get(
originDstTensorType.getShape(),
originDstTensorType.getElementType(),
newEnc
);
Value newBroadcast = builder.create<triton::BroadcastOp>(
broadcast.getLoc(), newType, src
);
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(), broadcastAxis);
auto newType =
RankedTensorType::get(originDstTensorType.getShape(),
originDstTensorType.getElementType(), newEnc);
Value newBroadcast =
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast
);
broadcast.getLoc(), originDstType, newBroadcast);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);