From 099918b3c028cbde71e17bf5caf546859264b4d6 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 21 Oct 2019 18:58:02 -0400
Subject: [PATCH] [python] [ops] added skeleton for einsum op

---
 python/examples/einsum.py     |   2 +-
 python/triton/ops/__init__.py |   1 +
 python/triton/ops/einsum.py   | 128 ++++++++++++++++++++++++++++++++++
 3 files changed, 130 insertions(+), 1 deletion(-)
 create mode 100644 python/triton/ops/einsum.py

diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index 6d535b911..708277f2d 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -107,7 +107,7 @@
 qk = np.einsum("bhak,hank->bhan", q, kw)
 tq = torch.from_numpy(q).contiguous().cuda()
 tkw = torch.from_numpy(kw).contiguous().cuda()
-tqk = dot(tq, tkw)
+tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
 
 diff = qk - tqk.cpu().numpy()
 print(np.max(diff))
diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py
index f995b88f1..3edb82406 100644
--- a/python/triton/ops/__init__.py
+++ b/python/triton/ops/__init__.py
@@ -1 +1,2 @@
 from .dot import dot
+from .einsum import einsum
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
new file mode 100644
index 000000000..e33da2451
--- /dev/null
+++ b/python/triton/ops/einsum.py
@@ -0,0 +1,128 @@
+import triton
+
+class _einsum(triton.function):
+
+    src = """
+    void einsum(TYPE * A, TYPE * B, TYPE * C,
+                int dim_M, int dim_N, int dim_K,
+                int std_A0, int std_B0, int std_C0,
+                int std_A1, int std_B1, int std_C1) {
+      int pid0 = get_program_id(0);
+      int pid1 = get_program_id(1);
+    }
+    """
+
+    kernel = triton.kernel(src, ['C'])
+
+    @staticmethod
+    def _append_dim(dim_data, dim_type, idx, label, dim, stride):
+        if dim_type in dim_data:
+            data = dim_data[dim_type]
+            if idx != data["idx"] + 1:
+                raise ValueError("aggregate inner, outer and batch dims must be adjacent to each other.")
+            data["idx"] = idx  # advance so that runs of 3+ adjacent dims still pass the check above
+            data["dim"] *= dim
+            data["lab"] = label + data["lab"]
+        else:
+            dim_data[dim_type] = dict(idx=idx, lab=label, dim=dim, std=stride)
+        return dim_type
+
+    @staticmethod
+    def _parse_abc(labels_a, labels_b, labels_c, shape_a, is_a=False):
+
+        if len(labels_a) != len(shape_a):
+            raise ValueError(f"einsum notation dims do not match shape: {labels_a} {shape_a}")
+
+        trans = False
+        stride = 1
+        std1 = None
+        data = dict()
+        for idx, (lab, dim) in enumerate(reversed(list(zip(labels_a, shape_a)))):
+            if dim is None:
+                raise ValueError("einsum doesn't currently work on shapes with placeholder dims.")
+            if idx == 0 and dim % 8 != 0:
+                raise ValueError("contiguous dim must be a multiple of 8")
+
+            if lab in labels_c:
+                # batch dim
+                if lab in labels_b:
+                    _einsum._append_dim(data, "B", idx, lab, dim, stride)
+                    if idx == 0:
+                        raise ValueError(f"batch dim cannot be the contiguous dim: {lab} {labels_a} {shape_a}")
+                # outer dim
+                else:
+                    std1 = _einsum._append_dim(data, "O", idx, lab, dim, stride)
+                    if idx == 0:
+                        trans = is_a
+            # inner dim
+            elif lab in labels_b:
+                std1 = _einsum._append_dim(data, "I", idx, lab, dim, stride)
+                if idx == 0:
+                    trans = not is_a
+            else:
+                raise ValueError(f"label {lab} of {labels_a} does not appear in the other operand or the output")
+
+            stride *= dim
+
+        if "B" not in data:
+            data["B"] = dict(dim=1, std=1)
+
+        # batch, outer, inner, std0, std1, trans
+        return data["B"]["dim"], data["O"]["dim"], data["I"]["dim"], data["B"]["std"], data[std1]["std"], trans
+
+    @staticmethod
+    def _parse_einsum(labels_a, labels_b, labels_c, shape_a, shape_b):
+
+        dims_a = dict(zip(labels_a, shape_a))
+        dims_b = dict(zip(labels_b, shape_b))
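+        # infer the output shape: each output label takes its size from
+        # whichever input carries that label (checked in the loop below)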
+        shape_c = list()
+        for lab in labels_c:
+            if lab in dims_a:
+                shape_c.append(dims_a[lab])
+            elif lab in dims_b:
+                shape_c.append(dims_b[lab])
+            else:
+                raise ValueError(f"output label {lab} ({labels_c}) is not present in either input ({labels_a}, {labels_b})")
+
+        BA, M, KA, std_a0, std_a1, ta = _einsum._parse_abc(labels_a, labels_b, labels_c, shape_a, True)
+        BB, N, KB, std_b0, std_b1, tb = _einsum._parse_abc(labels_b, labels_a, labels_c, shape_b, False)
+        BC, _, _, std_c0, std_c1, _ = _einsum._parse_abc(labels_c, labels_b, labels_a, shape_c)
+
+        if not (BA == BB == BC):
+            raise ValueError("mismatched batch dims")
+        if KA != KB:
+            raise ValueError("mismatched reduction dims")
+
+        return shape_c, (BA, M, N, KA), (std_a0, std_b0, std_c0), (std_a1, std_b1, std_c1), ta, tb
+
+    @staticmethod
+    def call(a, b, trans_a, trans_b, shape_c, bmnk,
+             std0, std1, einsum_a, einsum_b, einsum_c):
+        dtype = a.dtype
+        c = triton.empty(shape_c, dtype)
+        grid = lambda opt: (1, 1, 1)
+        return _einsum.kernel(a, b, c,
+                              bmnk[1], bmnk[2], bmnk[3],
+                              std0[0], std0[1], std0[2],
+                              std1[0], std1[1], std1[2],
+                              grid,
+                              TYPE=['float'])
+
+
+    @staticmethod
+    def forward(ctx, subscripts, a, b):
+        if isinstance(subscripts, str):
+            einsum_a, einsum_bc = subscripts.split(",")
+            einsum_b, einsum_c = einsum_bc.split("->")
+        else:
+            einsum_a, einsum_b, einsum_c = subscripts
+
+        shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
+            einsum_a, einsum_b, einsum_c,
+            a.shape, b.shape
+        )
+        return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
+
+einsum = _einsum.apply
\ No newline at end of file
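
For reference, a minimal sketch of how the new op is meant to be driven, adapted
from the change to python/examples/einsum.py above. The tensor sizes here are
placeholders I chose so that the contiguous (last) dim of each operand and of
the output is a multiple of 8, as _parse_abc requires. Since the kernel body is
still a skeleton, the call only exercises subscript parsing and returns an
uninitialized tensor of the right shape; the comparison against NumPy will only
become meaningful once the kernel is implemented:

    import numpy as np
    import torch
    import triton

    # placeholder sizes; the contiguous dims (K for the inputs, N for the
    # output) are multiples of 8, per the check in _parse_abc
    B, H, A, N, K = 4, 2, 8, 8, 16
    q  = np.random.rand(B, H, A, K).astype(np.float32)   # labels "bhak"
    kw = np.random.rand(H, A, N, K).astype(np.float32)   # labels "hank"
    qk = np.einsum("bhak,hank->bhan", q, kw)             # NumPy reference

    tq  = torch.from_numpy(q).contiguous().cuda()
    tkw = torch.from_numpy(kw).contiguous().cuda()
    tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)

    # only small once the kernel actually computes the contraction
    print(np.max(np.abs(qk - tqk.cpu().numpy())))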