[FRONTEND][BACKEND] Clean-up transpositions (#953)
This commit is contained in:
@@ -339,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
|
@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
// TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
|
||||
bool transA = false;
|
||||
if (transA) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
}
|
||||
|
||||
ValueTable ha;
|
||||
std::function<void(int, int)> loadFn;
|
||||
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
if (op.transA())
|
||||
std::swap(aShape[0], aShape[1]);
|
||||
|
||||
auto dShape = dTensorTy.getShape();
|
||||
|
||||
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
|
@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
size_t reduceAxis = op.transA() ? 0 : 1;
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShape[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
bool transA = op.transA();
|
||||
bool transB = op.transB();
|
||||
|
||||
// TODO[Superjomn]: order cannot accessed in DotOp.
|
||||
SmallVector<unsigned> AOrder({1, 0});
|
||||
SmallVector<unsigned> BOrder({1, 0});
|
||||
|
||||
if (transA) {
|
||||
std::swap(AShape[0], AShape[1]);
|
||||
std::swap(AOrder[0], AOrder[1]);
|
||||
}
|
||||
if (transB) {
|
||||
std::swap(BShape[0], BShape[1]);
|
||||
std::swap(BOrder[0], BOrder[0]);
|
||||
}
|
||||
|
||||
bool isARow = AOrder[0] != 0;
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
|
@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
|
||||
adaptor.transB());
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
|
||||
adaptor.allowTF32());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
||||
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
||||
def CombineDotAddIPattern : Pat<
|
||||
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
|
||||
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
def CombineDotAddFPattern : Pat<
|
||||
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
|
||||
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
|
||||
def CombineDotAddIRevPattern : Pat<
|
||||
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
|
||||
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
def CombineDotAddFRevPattern : Pat<
|
||||
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
|
||||
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
|
||||
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
|
||||
|
||||
|
@@ -781,8 +781,7 @@ public:
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
|
||||
dotOp.transA(), dotOp.transB());
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
|
@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_dot",
|
||||
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
|
||||
mlir::Value &c, bool allowTF32, bool transA,
|
||||
bool transB) -> mlir::Value {
|
||||
mlir::Value &c, bool allowTF32) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
|
||||
allowTF32, transA, transB);
|
||||
allowTF32);
|
||||
})
|
||||
.def("create_exp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
|
@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws))
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
denom = np.sum(num, axis=-1, keepdims=True)
|
||||
z_ref = num / denom
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref.T, w)
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
if isinstance(lhs, triton.language.tensor):
|
||||
if node.attr == "T":
|
||||
return triton.language.semantic.trans(lhs, builder=self.builder)
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
|
@@ -603,10 +603,9 @@ class tensor:
|
||||
assert False, "unsupported"
|
||||
return ret
|
||||
|
||||
|
||||
# x[:, None, :, None]
|
||||
# x = expand_dims(x, axis=1)
|
||||
# x = expand_dims(x, axis=2)
|
||||
@property
|
||||
def T(self):
|
||||
assert False, "Transposition must be created by the AST Visitor"
|
||||
|
||||
@builtin
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
@@ -766,7 +765,7 @@ def view(input, shape, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
|
||||
def dot(input, other, allow_tf32=True, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
@@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
|
||||
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||
"""
|
||||
allow_tf32 = _constexpr_to_value(allow_tf32)
|
||||
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
|
||||
return semantic.dot(input, other, allow_tf32, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
|
@@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
if lhs.type.scalar.is_int():
|
||||
@@ -986,11 +984,11 @@ def dot(lhs: tl.tensor,
|
||||
else:
|
||||
_0 = builder.get_float32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
M = lhs.type.shape[1 if trans_a else 0]
|
||||
N = rhs.type.shape[0 if trans_b else 1]
|
||||
M = lhs.type.shape[0]
|
||||
N = rhs.type.shape[1]
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
|
||||
|
||||
|
@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
|
||||
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
|
||||
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
|
Reference in New Issue
Block a user