[FRONTEND][BACKEND] Clean-up transpositions (#953)
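Summary of the change: tt.dot (and the Python-level tl.dot) no longer takes transA/transB (trans_a/trans_b) flags; a transposition is now written explicitly on the operand, for example through the new .T attribute that CodeGenerator.visit_Attribute lowers via semantic.trans. Below is a minimal sketch of the new frontend usage, assuming a build that contains this change; the kernel name, tensor names, and block sizes are illustrative and not taken from the repository.

import torch
import triton
import triton.language as tl


@triton.jit
def xt_dot_kernel(X, Y, Z, K: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    # X holds a (K, M) block, Y a (K, N) block; we want Z = X^T @ Y.
    offs_k = tl.arange(0, K)
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    x = tl.load(X + offs_k[:, None] * M + offs_m[None, :])  # (K, M)
    y = tl.load(Y + offs_k[:, None] * N + offs_n[None, :])  # (K, N)
    # Before this commit: z = tl.dot(x, y, trans_a=True)
    # After this commit: the transpose is explicit on the operand.
    z = tl.dot(x.T, y)
    tl.store(Z + offs_m[:, None] * N + offs_n[None, :], z)


x = torch.randn((32, 16), device="cuda", dtype=torch.float16)
y = torch.randn((32, 16), device="cuda", dtype=torch.float16)
z = torch.empty((16, 16), device="cuda", dtype=torch.float32)
xt_dot_kernel[(1,)](x, y, z, K=32, M=16, N=16)

At the IR level the resulting tt.dot carries only the allowTF32 attribute, as the updated CHECK lines in the combine test below reflect.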
@@ -339,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
     $d = matrix_multiply($a, $b) + $c
   }];
 
-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
 
   let results = (outs TT_FpIntTensor:$d);
 
@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
 
     SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
                                aTensorTy.getShape().end());
-    // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-    bool transA = false;
-    if (transA) {
-      std::swap(shape[0], shape[1]);
-    }
 
     ValueTable ha;
     std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
 
     SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
                                 aTensorTy.getShape().end());
-    if (op.transA())
-      std::swap(aShape[0], aShape[1]);
 
     auto dShape = dTensorTy.getShape();
 
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
  SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                              sharedLayout.getOrder().end());
 
- Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
 
+ Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
  bool isBRow = order[0] != 0;
  bool isBVec4 = isBRow && shape[order[0]] <= 16;
  // TODO[Superjomn]: Support the case when isBVec4=false later
@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = op.transA() ? 0 : 1;
+    size_t reduceAxis = 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
 
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();
 
-  bool transA = op.transA();
-  bool transB = op.transB();
-
   // TODO[Superjomn]: order cannot accessed in DotOp.
   SmallVector<unsigned> AOrder({1, 0});
   SmallVector<unsigned> BOrder({1, 0});
 
-  if (transA) {
-    std::swap(AShape[0], AShape[1]);
-    std::swap(AOrder[0], AOrder[1]);
-  }
-  if (transB) {
-    std::swap(BShape[0], BShape[1]);
-    std::swap(BOrder[0], BOrder[0]);
-  }
-
   bool isARow = AOrder[0] != 0;
   bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
                                       bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
     }
-    rewriter.replaceOpWithNewOp<triton::DotOp>(
-        op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
-        adaptor.transB());
+    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
+                                               adaptor.allowTF32());
     return success();
   }
 };
@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
 // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 def CombineDotAddIPattern : Pat<
-        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFPattern : Pat<
-        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 
 def CombineDotAddIRevPattern : Pat<
-        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFRevPattern : Pat<
-        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 
 
@@ -781,8 +781,7 @@ public:
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
-        dotOp.transA(), dotOp.transB());
+        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
 
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
       })
       .def("create_dot",
            [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
-              mlir::Value &c, bool allowTF32, bool transA,
-              bool transB) -> mlir::Value {
+              mlir::Value &c, bool allowTF32) -> mlir::Value {
              auto loc = self.getUnknownLoc();
              return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
-                                                     allowTF32, transA, transB);
+                                                     allowTF32);
            })
       .def("create_exp",
            [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         if CHAIN_DOT:
             # tl.store(Zs, z)
             # tl.debug_barrier()
-            z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
+            z = tl.dot(z.to(tl.float16), tl.load(Ws))
             tl.store(Zs, z)
     # input
     rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         denom = np.sum(num, axis=-1, keepdims=True)
         z_ref = num / denom
     if epilogue == 'chain-dot':
-        z_ref = np.matmul(z_ref.T, w)
+        z_ref = np.matmul(z_ref, w)
     # compare
     # print(z_ref[:,0], z_tri[:,0])
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
 
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
+        if isinstance(lhs, triton.language.tensor):
+            if node.attr == "T":
+                return triton.language.semantic.trans(lhs, builder=self.builder)
         return getattr(lhs, node.attr)
 
     def visit_Expr(self, node):
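For illustration only, a plain-Python stand-in for the dispatch the new visit_Attribute hook performs; no Triton build is needed to run it, and FakeTensor, trans, and visit_attribute are hypothetical names introduced here, not project APIs.

class FakeTensor:
    # Stand-in for triton.language.tensor in this illustration.
    def __init__(self, shape):
        self.shape = shape


def trans(t):
    # Stand-in for triton.language.semantic.trans: swap the two dimensions.
    return FakeTensor(list(reversed(t.shape)))


def visit_attribute(value, attr):
    # Mirrors the behaviour added above: ".T" on a tensor becomes a
    # transposition, every other attribute is a normal lookup.
    if isinstance(value, FakeTensor) and attr == "T":
        return trans(value)
    return getattr(value, attr)


print(visit_attribute(FakeTensor([32, 16]), "T").shape)  # [16, 32]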
@@ -603,10 +603,9 @@ class tensor:
         assert False, "unsupported"
         return ret
 
-
-    # x[:, None, :, None]
-    # x = expand_dims(x, axis=1)
-    # x = expand_dims(x, axis=2)
+    @property
+    def T(self):
+        assert False, "Transposition must be created by the AST Visitor"
 
     @builtin
     def to(self, dtype, bitcast=False, _builder=None):
@@ -766,7 +765,7 @@ def view(input, shape, _builder=None):
 
 
 @builtin
-def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
+def dot(input, other, allow_tf32=True, _builder=None):
     """
     Returns the matrix product of two blocks.
 
@@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
     :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     """
     allow_tf32 = _constexpr_to_value(allow_tf32)
-    return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
+    return semantic.dot(input, other, allow_tf32, _builder)
 
 
 # -----------------------
@@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
 def dot(lhs: tl.tensor,
         rhs: tl.tensor,
         allow_tf32: bool,
-        trans_a: bool,
-        trans_b: bool,
         builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
     if lhs.type.scalar.is_int():
@@ -986,11 +984,11 @@ def dot(lhs: tl.tensor,
     else:
         _0 = builder.get_float32(0)
         ret_scalar_ty = tl.float32
-    M = lhs.type.shape[1 if trans_a else 0]
-    N = rhs.type.shape[0 if trans_b else 1]
+    M = lhs.type.shape[0]
+    N = rhs.type.shape[1]
     _0 = builder.create_splat(_0, [M, N])
     ret_ty = tl.block_type(ret_scalar_ty, [M, N])
-    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
+    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                      ret_ty)
 
 
@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
 
   %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
 
-  // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+  // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
   %res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
 
-  // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+  // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
   %res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
 
   return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>