[FRONTEND][BACKEND] Clean-up transpositions (#953)

Author: Philippe Tillet (committed by GitHub)
Date:   2022-12-06 09:32:13 -08:00
Parent: 16e973edf2
Commit: 532e10cf87
12 changed files with 31 additions and 53 deletions
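
At the API level, this change removes the trans_a / trans_b keyword arguments from tl.dot (and the transA / transB attributes from tt.dot); a transposition is now written on the operand itself, via tl.trans or the new .T attribute handled by the code generator. A minimal before/after sketch follows; the kernel name, pointer arguments, and tile indexing are illustrative assumptions, not code from this repository.

import triton
import triton.language as tl

@triton.jit
def xty_kernel(X, Y, Z, K: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    # Illustrative row-major tile loads; the strides are assumed for this sketch.
    k = tl.arange(0, K)
    m = tl.arange(0, M)
    n = tl.arange(0, N)
    x = tl.load(X + k[:, None] * M + m[None, :])  # (K, M) tile
    y = tl.load(Y + k[:, None] * N + n[None, :])  # (K, N) tile
    # Before #953: z = tl.dot(x, y, trans_a=True)
    # After #953: transpose the operand explicitly.
    z = tl.dot(tl.trans(x), y)  # (M, N); tl.dot no longer takes trans_a/trans_b
    tl.store(Z + m[:, None] * N + n[None, :], z)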

View File

@@ -339,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];
- let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
+ let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);

View File

@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
- // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
- bool transA = false;
- if (transA) {
- std::swap(shape[0], shape[1]);
- }
ValueTable ha;
std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
- if (op.transA())
- std::swap(aShape[0], aShape[1]);
auto dShape = dTensorTy.getShape();
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later

View File

@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
- size_t reduceAxis = op.transA() ? 0 : 1;
+ size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
- bool transA = op.transA();
- bool transB = op.transB();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
- if (transA) {
- std::swap(AShape[0], AShape[1]);
- std::swap(AOrder[0], AOrder[1]);
- }
- if (transB) {
- std::swap(BShape[0], BShape[1]);
- std::swap(BOrder[0], BOrder[0]);
- }
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes

View File

@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
- rewriter.replaceOpWithNewOp<triton::DotOp>(
- op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
- adaptor.transB());
+ rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
+ adaptor.allowTF32());
return success();
}
};

View File

@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
- (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
- (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+ (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+ (TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
- (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
- (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+ (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+ (TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
- (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
- (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+ (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+ (TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
- (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
- (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+ (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+ (TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;

View File

@@ -781,8 +781,7 @@ public:
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
- dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
- dotOp.transA(), dotOp.transB());
+ dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());

View File

@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
- mlir::Value &c, bool allowTF32, bool transA,
- bool transB) -> mlir::Value {
+ mlir::Value &c, bool allowTF32) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
- allowTF32, transA, transB);
+ allowTF32);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {

View File

@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
- z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
+ z = tl.dot(z.to(tl.float16), tl.load(Ws))
tl.store(Zs, z)
# input
rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
- z_ref = np.matmul(z_ref.T, w)
+ z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

View File

@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Attribute(self, node):
lhs = self.visit(node.value)
+ if isinstance(lhs, triton.language.tensor):
+     if node.attr == "T":
+         return triton.language.semantic.trans(lhs, builder=self.builder)
return getattr(lhs, node.attr)
def visit_Expr(self, node):
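
The .T sugar is resolved entirely at trace time: visit_Attribute intercepts the attribute access and emits triton.language.semantic.trans, which is why the T property on tl.tensor (changed further below) only asserts. A short usage sketch, with an illustrative kernel that is not part of this commit:

import triton
import triton.language as tl

@triton.jit
def store_transpose(X, Z, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs[:, None] * N + offs[None, :])  # (N, N) tile
    # "x.T" never reaches tensor.T at runtime: the code generator rewrites the
    # attribute access into triton.language.semantic.trans(x, builder=...).
    tl.store(Z + offs[:, None] * N + offs[None, :], x.T)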

View File

@@ -603,10 +603,9 @@ class tensor:
assert False, "unsupported"
return ret
# x[:, None, :, None]
# x = expand_dims(x, axis=1)
# x = expand_dims(x, axis=2)
+ @property
+ def T(self):
+     assert False, "Transposition must be created by the AST Visitor"
@builtin
def to(self, dtype, bitcast=False, _builder=None):
@@ -766,7 +765,7 @@ def view(input, shape, _builder=None):
@builtin
- def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
+ def dot(input, other, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
- return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
+ return semantic.dot(input, other, allow_tf32, _builder)
# -----------------------

View File

@@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
- trans_a: bool,
- trans_b: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
if lhs.type.scalar.is_int():
@@ -986,11 +984,11 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
- M = lhs.type.shape[1 if trans_a else 0]
- N = rhs.type.shape[0 if trans_b else 1]
+ M = lhs.type.shape[0]
+ N = rhs.type.shape[1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
- return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
+ return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
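
With trans_a / trans_b gone, semantic.dot reads the result shape directly from the operand shapes, so any transposition must be applied (via tl.trans or .T) before the call. A pure-Python sketch of the shape rule; the helper name and the compatibility assert are illustrative, not code from this file:

def dot_result_shape(lhs_shape, rhs_shape):
    # lhs is (M, K), rhs is (K, N); operands arrive already transposed
    # if the user wanted A^T or B^T.
    M, K = lhs_shape
    K2, N = rhs_shape
    assert K == K2, "inner dimensions must match"  # illustrative check
    return [M, N]

assert dot_result_shape([128, 64], [64, 32]) == [128, 32]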

View File

@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
- // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+ // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
- // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+ // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>