[FRONTEND] Added ExpandDimsOp primitive (#36)

Philippe Tillet
2022-08-04 18:41:06 -07:00
committed by GitHub
parent a7b49b3227
commit 78ebbe24c7
8 changed files with 98 additions and 41 deletions

View File

@@ -133,6 +133,16 @@ def TT_GEPOp : TT_Op<"getelementptr",
//
// Shape Manipulation Ops
//
+def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
+  let summary = "expand_dims";
+  let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
+  let results = (outs TT_Tensor:$result);
+  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+}
def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "view";
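The axis is carried as an i32 attribute, so the printed form of the new op reads, e.g., %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>, as exercised by the updated alignment tests below.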

View File

@@ -167,8 +167,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
-  // Reshape
-  if (llvm::isa<triton::ViewOp>(op)) {
+  // expandDims
+  if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -176,28 +176,12 @@ ChangeResult AxisInfoAnalysis::visitOperation(
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::DimVectorT contiguity;
-    AxisInfo::DimVectorT divisibility;
-    AxisInfo::DimVectorT constancy;
-    bool is_skewed = false;
-    size_t current = 0;
-    for (size_t d = 0; d < retTy.getRank(); d++) {
-      if (retShape[d] == 1) {
-        contiguity.push_back(1);
-        divisibility.push_back(1);
-        constancy.push_back(1);
-      } else if (!is_skewed && retShape[d] == opShape[current]) {
-        contiguity.push_back(opInfo.getContiguity()[current]);
-        divisibility.push_back(opInfo.getDivisibility()[current]);
-        constancy.push_back(opInfo.getConstancy()[current]);
-        current++;
-      } else {
-        is_skewed = true;
-        contiguity.push_back(1);
-        divisibility.push_back(1);
-        constancy.push_back(1);
-      }
-    }
+    AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
+    AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
+    AxisInfo::DimVectorT constancy = opInfo.getConstancy();
+    contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
+    divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
+    constancy.insert(constancy.begin() + expandDims.axis(), 1);
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Broadcast
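The old loop re-matched result dims against operand dims and bailed out ("skewed") on the first mismatch; because tt.expand_dims names the inserted axis explicitly, the new rule reduces to inserting the neutral value 1 into each lattice vector. A minimal Python sketch of the same rule (hypothetical helper, for illustration only):

    def expand_dims_axis_info(contiguity, divisibility, constancy, axis):
        # Insert the neutral value 1 at the new unit dimension; every other
        # per-dimension fact carries over unchanged from the operand.
        for vec in (contiguity, divisibility, constancy):
            vec.insert(axis, 1)
        return contiguity, divisibility, constancy

    # e.g. axis=1 on ([128], [65536], [1]) gives ([128, 1], [65536, 1], [1, 1]),
    # matching the updated test-alignment checks below.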

View File

@@ -5,6 +5,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include <numeric>
using namespace mlir;
using namespace mlir::triton;
@@ -142,6 +143,46 @@ struct TritonMakeRangePattern
}
};
+struct TritonExpandDimsPattern
+    : public OpConversionPattern<triton::ExpandDimsOp> {
+  using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // argument encoding
+    RankedTensorType argType = adaptor.src().getType().cast<RankedTensorType>();
+    Attribute _argEncoding = argType.getEncoding();
+    if (!_argEncoding)
+      return failure();
+    auto argEncoding =
+        _argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
+    // return shape
+    auto retShape = argType.getShape().vec();
+    retShape.insert(retShape.begin() + op.axis(), 1);
+    // return encoding: insert a unit entry at `axis` into every
+    // per-dimension field of the blocked layout
+    auto retSizePerThread = argEncoding.getSizePerThread().vec();
+    retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1);
+    auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
+    retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1);
+    auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
+    retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
+    // reset the order to the identity permutation
+    SmallVector<unsigned, 4> retOrder(retShape.size());
+    std::iota(retOrder.begin(), retOrder.end(), 0);
+    triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
+        triton::gpu::TritonGPUBlockedEncodingAttr::get(
+            getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
+            retOrder);
+    // return type
+    RankedTensorType retType =
+        RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
+    // construct new op
+    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
+        op, retType, adaptor.src(), adaptor.axis());
+    return success();
+  }
+};
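Conceptually, the pattern inserts a unit entry at `axis` into every per-dimension field of the blocked encoding; the unit dimension holds a single element, so it costs no extra threads or warps. Note that the order is reset to the identity permutation rather than remapped around the new axis. A rough Python model of that bookkeeping (hypothetical names, not the Triton API):

    def expand_blocked_encoding(size_per_thread, threads_per_warp,
                                warps_per_cta, axis):
        # Each per-dimension list gains a 1 at the new unit axis.
        for vec in (size_per_thread, threads_per_warp, warps_per_cta):
            vec.insert(axis, 1)
        # The pattern resets the order to the identity permutation.
        order = list(range(len(size_per_thread)))
        return size_per_thread, threads_per_warp, warps_per_cta, order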
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
@@ -260,8 +301,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
-      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
-      TritonStorePattern>(typeConverter, context);
+      TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
+      TritonLoadPattern, TritonStorePattern>(typeConverter, context);
}
//

View File

@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
})
-      // Block instruction
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::ViewOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg);
})
+      .def("create_expand_dims",
+           [](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
+             auto argEltType = argType.getElementType();
+             std::vector<int64_t> retShape = argType.getShape();
+             retShape.insert(retShape.begin() + axis, 1);
+             return self.create<mlir::triton::ExpandDimsOp>(
+                 loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
+                 axis);
+           })
.def("create_cat",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {

View File

@@ -556,18 +556,22 @@ class tensor:
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
slices = [slices]
-        src_shape = self.shape
-        dst_shape = []
-        curr = 0
-        for sl in slices:
+        ret = self
+        for dim, sl in enumerate(slices):
             if isinstance(sl, constexpr) and sl.value is None:
-                dst_shape.append(1)
+                # `dim` indexes the destination shape directly, so it is the
+                # correct axis even after earlier insertions
+                ret = semantic.expand_dims(ret, dim, _builder)
             elif sl == slice(None, None, None):
-                dst_shape.append(src_shape[curr].value)
-                curr += 1
+                pass
             else:
                 assert False, "unsupported"
-        ret = semantic.view(self, dst_shape, _builder)
         return ret
    # e.g. x[:, None, :, None] becomes
    #   x = expand_dims(x, axis=1)
    #   x = expand_dims(x, axis=3)
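A quick sketch of what the rewritten indexing emits (shapes follow NumPy's None semantics; assume x is a 1-D tl.tensor of shape [N] inside a kernel):

    y = x[:, None]  # expand_dims(x, axis=1) -> shape [N, 1]
    z = x[None, :]  # expand_dims(x, axis=0) -> shape [1, N]

Each None inserts one unit dimension at its position in the subscript, while ':' leaves the corresponding source dimension in place.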
@builtin
def to(self, dtype, bitcast=False, _builder=None):
if isinstance(bitcast, constexpr):

View File

@@ -463,6 +463,13 @@ def view(input: tl.tensor,
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
+def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    dst_shape = [s for s in input.type.shape]
+    dst_shape.insert(axis, 1)
+    ret_ty = tl.block_type(input.type.scalar, dst_shape)
+    return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
# TODO: check types
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
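The Python-side shape bookkeeping in expand_dims mirrors the C++ builder binding above: a unit dimension is inserted at `axis` and the element type is untouched. A worked example with illustrative shapes:

    dst_shape = [64, 128]
    dst_shape.insert(1, 1)  # expand_dims at axis=1
    assert dst_shape == [64, 1, 128]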

View File

@@ -10,7 +10,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
@@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %7 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
@@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %11 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
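These checks are exactly what the new insertion rule predicts: %0 and %1 reach tt.expand_dims with Contiguity [128], Divisibility [65536], Constancy [1], and inserting 1 at axis 1 gives [128, 1] / [65536, 1] / [1, 1] (for %2 and %11), while inserting at axis 0 gives [1, 128] / [1, 65536] / [1, 1] (for %7 and %14).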

View File

@@ -22,12 +22,12 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
-    %1 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
-    %6 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+    %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>