From 78ebbe24c72a527d8ac393c484fffe355d286cf7 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 4 Aug 2022 18:41:06 -0700
Subject: [PATCH] [FRONTEND] Added `ExpandDimsOp` primitive (#36)

---
 include/triton/Dialect/Triton/IR/TritonOps.td | 10 +++++
 lib/Analysis/AxisInfo.cpp                     | 32 ++++---------
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   | 45 ++++++++++++++++++-
 python/src/triton.cc                          | 13 +++++-
 python/triton/language/core.py                | 20 +++++----
 python/triton/language/semantic.py            |  7 +++
 test/Analysis/test-alignment.mlir             |  8 ++--
 test/TritonGPU/coalesce.mlir                  |  4 +-
 8 files changed, 98 insertions(+), 41 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index ba65e11fe..06f0d406b 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -133,6 +133,16 @@ def TT_GEPOp : TT_Op<"getelementptr",
 //
 // Shape Manipulation Ops
 //
+def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
+  let summary = "expand_dims";
+
+  let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
+
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+}
+
 def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
   let summary = "view";
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 5fa10769c..6bdcc47fa 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -167,8 +167,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     }
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
-  // Reshape
-  if (llvm::isa<triton::ViewOp>(op)) {
+  // expandDims
+  if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
     Type _retTy = *op->result_type_begin();
     Type _opTy = *op->operand_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();
     TensorType opTy = _opTy.cast<TensorType>();
     ArrayRef<int64_t> retShape = retTy.getShape();
     ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::DimVectorT contiguity;
-    AxisInfo::DimVectorT divisibility;
-    AxisInfo::DimVectorT constancy;
-    bool is_skewed = false;
-    size_t current = 0;
-    for (size_t d = 0; d < retTy.getRank(); d++) {
-      if (retShape[d] == 1) {
-        contiguity.push_back(1);
-        divisibility.push_back(1);
-        constancy.push_back(1);
-      } else if (!is_skewed && retShape[d] == opShape[current]) {
-        contiguity.push_back(opInfo.getContiguity()[current]);
-        divisibility.push_back(opInfo.getDivisibility()[current]);
-        constancy.push_back(opInfo.getConstancy()[current]);
-        current++;
-      } else {
-        is_skewed = true;
-        contiguity.push_back(1);
-        divisibility.push_back(1);
-        constancy.push_back(1);
-      }
-    }
+    AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
+    AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
+    AxisInfo::DimVectorT constancy = opInfo.getConstancy();
+    contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
+    divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
+    constancy.insert(constancy.begin() + expandDims.axis(), 1);
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
   // Broadcast
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 4147a6256..a8a4ba807 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -5,6 +5,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include <numeric>

 using namespace mlir;
 using namespace mlir::triton;
@@ -142,6 +143,46 @@ struct TritonMakeRangePattern
   }
 };

+struct TritonExpandDimsPattern
+    : public OpConversionPattern<triton::ExpandDimsOp> {
+  using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Type retType = op.getType());
+    RankedTensorType argType = adaptor.src().getType().cast<RankedTensorType>();
+    Attribute _argEncoding = argType.getEncoding();
+    if (!_argEncoding)
+      return failure();
+    auto argEncoding =
+        _argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
+    // return shape
+    auto retShape = argType.getShape().vec();
+    retShape.insert(retShape.begin() + op.axis(), 1);
+    // return encoding
+    auto retSizePerThread = argEncoding.getSizePerThread().vec();
+    retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1);
+    auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
+    retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1);
+    auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
+    retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
+    SmallVector<unsigned, 4> retOrder(retShape.size());
+    std::iota(retOrder.begin(), retOrder.end(), 0);
+    triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
+        triton::gpu::TritonGPUBlockedEncodingAttr::get(
+            getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
+            retOrder);
+    // return type
+    RankedTensorType retType =
+        RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
+    // construct new op
+    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
+        op, retType, adaptor.src(), adaptor.axis());
+    return success();
+  }
+};
+
 struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   using OpConversionPattern<triton::DotOp>::OpConversionPattern;

@@ -260,8 +301,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern, TritonGenericPattern,
       TritonBroadcastPattern, TritonGenericPattern, TritonReducePattern,
-      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
-      TritonStorePattern>(typeConverter, context);
+      TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
+      TritonLoadPattern, TritonStorePattern>(typeConverter, context);
 }

 //
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 91c5ef739..16c8e0b66 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
           })
-      // Block instruction
      .def("create_view",
           [](mlir::OpBuilder &self, mlir::Value &arg,
              std::vector<int64_t> &shape) -> mlir::Value {
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::triton::ViewOp>(
                 loc, mlir::RankedTensorType::get(shape, argType), arg);
           })
+      .def(
+          "create_expand_dims",
+          [](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
+            auto loc = self.getUnknownLoc();
+            auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
+            auto argEltType = argType.getElementType();
+            std::vector<int64_t> retShape = argType.getShape();
+            retShape.insert(retShape.begin() + axis, 1);
+            return self.create<mlir::triton::ExpandDimsOp>(
+                loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
+                axis);
+          })
      .def("create_cat",
           [](mlir::OpBuilder &self, mlir::Value &lhs,
              mlir::Value &rhs) -> mlir::Value {
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 477633d2c..df5477db2 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -556,18 +556,22 @@ class tensor:
     def __getitem__(self, slices, _builder=None):
         if isinstance(slices, slice):
             slices = [slices]
-        src_shape = self.shape
-        dst_shape = []
-        curr = 0
-        for sl in slices:
+        ret = self
+        n_inserted = 0
+        for dim, sl in enumerate(slices):
             if isinstance(sl, constexpr) and sl.value is None:
-                dst_shape.append(1)
+                ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
+                n_inserted += 1
             elif sl == slice(None, None, None):
-                dst_shape.append(src_shape[curr].value)
-                curr += 1
-        ret = semantic.view(self, dst_shape, _builder)
+                pass
+            else:
+                assert False, "unsupported"
         return ret

+    # x[:, None, :, None]
+    # x = expand_dims(x, axis=1)
+    # x = expand_dims(x, axis=2)
+
     @builtin
     def to(self, dtype, bitcast=False, _builder=None):
         if isinstance(bitcast, constexpr):
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index a3d0d3385..45ad9e5ea 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -463,6 +463,13 @@ def view(input: tl.tensor,
     return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)


+def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    dst_shape = [s for s in input.type.shape]
+    dst_shape.insert(axis, 1)
+    ret_ty = tl.block_type(input.type.scalar, dst_shape)
+    return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
+
+
 def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
     # TODO: check types
     return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index 041b88db2..f9f922846 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -10,7 +10,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
   %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
@@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
   %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
@@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
   %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
   %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index 1bd8e73d8..d55e5ed76 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -22,12 +22,12 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32},
   %cst = arith.constant dense : tensor<64x64xi1, #blocked1>
   %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
-  %1 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
   %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
   %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
   %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1>
   %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>
-  %6 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
   %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1>
   %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
   %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
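
Note (not part of the patch): the user-visible effect of this change is that indexing a
Triton tensor with `None` now lowers to one `tt.expand_dims` per inserted axis instead of a
single `tt.view`, and AxisInfo propagates the axis by inserting a 1 into contiguity,
divisibility, and constancy at the new dimension. Below is a minimal usage sketch of that
frontend path; the kernel, its arguments, and the strides are illustrative assumptions, not
code taken from this commit.

    import triton
    import triton.language as tl

    @triton.jit
    def copy_2d(in_ptr, out_ptr, stride_m, stride_n,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        # offs_m[:, None] and offs_n[None, :] go through tensor.__getitem__, which
        # now emits tt.expand_dims (axis=1 and axis=0 respectively) before the
        # usual broadcast when the pointer tensors are combined.
        in_ptrs = in_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        out_ptrs = out_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        tl.store(out_ptrs, tl.load(in_ptrs))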