[PYTHON][KERNEL] Added benchmarking functionalities for kernels

This commit is contained in:
Philippe Tillet
2019-10-27 15:32:34 -04:00
parent e11557855f
commit 0ec213547c
9 changed files with 207 additions and 112 deletions

View File

@@ -56,14 +56,12 @@ void disassociate::run(ir::module &mod) {
bld.set_insert_point(y);
bld.insert(cloned);
clone_map[y] = cloned;
// replace in above level
if(depth > 1){
// replace operands of parents
if(depth > 1)
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
}
else{
else
x.first->replace_uses_of_with(y, cloned);
}
}
depth += 1;
}

View File

@@ -2,11 +2,11 @@ import numpy as np
import triton
def run_tf():
M, N, K = 128, 128, 128
M, N, K = 2048, 2048, 2048
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False)
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
# Gradient
@@ -20,15 +20,20 @@ def run_tf():
sess.run(tf.global_variables_initializer())
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
b: hb})
# Benchmark
nanosec = triton.bench_registry[tr_d]
print('NANOSEC: ', nanosec)
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
# Test
print(result[0][0])
print(result[1][0])
dif = np.abs(result[0][0] - result[1][0])
print("dif: %f" % np.max(dif))
def run_torch():
torch.manual_seed(0)
M, N, K = 128, 128, 128
M, N, K = 2048, 2048, 2048
a = torch.randn(M, K).cuda()
b = torch.randn(K, N).cuda()
a.requires_grad_(True)
@@ -37,9 +42,8 @@ def run_torch():
torch_d = torch.matmul(torch.t(torch_c), b)
torch_y = torch.mean(torch_d)
triton_c = triton.ops.dot(a, b, False, True)
triton_d = triton.ops.dot(triton_c, b, True, False)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = torch.mean(triton_d)
# torch gradient
torch_y.backward()
torch_da = a.grad.clone()
@@ -51,6 +55,9 @@ def run_torch():
triton_da = a.grad.clone()
triton_db = b.grad.clone()
nanosec = triton.bench_registry[triton_d]
print(nanosec)
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (torch_da - triton_da).max())
print('Diff DB:', (torch_db - triton_db).max())

View File

@@ -12,7 +12,8 @@ from tensorflow.python.ops import gradient_checker
one = 0
out = 0
bench = 0
bench = 10
class ProdKeyTest(tf.test.TestCase):
def testEinsum(self):
@@ -36,9 +37,9 @@ class ProdKeyTest(tf.test.TestCase):
# key_dim = 16
for a_shape, b_shape, c_shape, einsum in [
[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
[ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ],
[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
#[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
[ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
#[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
]:
if one:
@@ -57,7 +58,7 @@ class ProdKeyTest(tf.test.TestCase):
e = tf.placeholder(tf.float32, c_shape, name="e")
feed_dict = { a:A, b:B, e:E }
cc = triton.ops.einsum(einsum, a, b)
cc = triton.ops.einsum(einsum, a, b, bench=bench)
# error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
# print(error)
@@ -71,8 +72,12 @@ class ProdKeyTest(tf.test.TestCase):
# c, = sess.run( [ c, ], feed_dict )
c, da, db = sess.run( [ cc, da, db ], feed_dict )
if bench == 0:
if bench > 0:
nanosec = triton.bench_registry[cc]
print(A.shape, B.shape)
print(nanosec)
else:
C = np.einsum(einsum, A, B)
id = cc.op.get_attr('id')
ctx = triton.ops._einsum.contexts[id]

View File

@@ -20,13 +20,13 @@ using namespace triton;
namespace rt = triton::runtime;
/* TF triton op properties */
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, double> fp64scalar_map;
std::map<size_t, int64_t> i64scalar_map;
/* Grid map */
void register_grid(size_t id,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
@@ -36,6 +36,8 @@ void delete_grid(size_t id) {
id_grid_map.erase(id);
}
/* Function map */
void register_fn(size_t id,
const std::string& src,
const rt::function::options_space_t& opt) {
@@ -56,8 +58,11 @@ size_t make_op_id() {
return id_fn_map.size();
}
/* TF scalar wrapper */
size_t make_scalar_id() {
return i64scalar_map.size();
size_t ret = i64scalar_map.size();
i64scalar_map[ret] = int64_t();
return ret;
}
bool has_scalar(size_t id) {
@@ -135,8 +140,9 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " (*id_fn_map.at(id_))({";
void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
@@ -146,7 +152,11 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id_), stream); \n";
os << "}, *id_grid_map.at(id_), stream);\n";
os << " };\n ";
os << " run();";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
}
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,7 +196,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
throw std::runtime_error("unknown output");
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
}
os << " .Attr(\"id: int\")" << std::endl;
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << ";\n";
}
@@ -247,6 +259,7 @@ std::tuple<std::string,
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -260,13 +273,15 @@ namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
class )" << opname << R"(: public OpKernel {
public:
explicit )" << opname << R"((OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
}
void Compute(OpKernelContext* context){
@@ -291,12 +306,14 @@ oss << R"(
oss << R"(
// launch function
)";
gen_make_launch_function(oss, fn->args());
gen_make_launch_function(oss, outputs.size(), fn->args());
oss << R"(
}
private:
int id_;
int bench_;
int bench_id_;
};
// register kernel builder
@@ -379,6 +396,7 @@ void gen_torch_signature(std::ostringstream& oss,
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
@@ -420,7 +438,8 @@ void gen_torch_make_handles(std::ostream &os,
}
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " (*id_fn_map.at(id))({";
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = "arg_" + arg->get_name();
@@ -431,7 +450,11 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << name;
}
os << "}, *id_grid_map.at(id), &stream);\n";
}
os << " };\n ";
os << " run();";
os << " if(bench > 0)\n ";
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
if(outputs.size() == 1){
@@ -465,6 +488,7 @@ std::tuple<std::string,
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/extension.h"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
@@ -479,6 +503,7 @@ namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
)";

View File

@@ -5,6 +5,7 @@ import shutil
import hashlib
import sysconfig
import sys
import weakref
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
@@ -176,6 +177,38 @@ def _make_grid(args) :
return grid
class bench_dict:
# Lazy entry for e.g., tensorflow, when value of benchmark is
# not known at graph compile time
class lazy_entry:
def __init__(self, id):
self.id = id
def get(self):
return libtriton.retrieve_scalar(self.id)
def __init__(self):
self.data = dict()
def __delitem__(self, key):
del self.data[id(key)]
def __getitem__(self, key):
ret = self.data[id(key)]
if isinstance(ret, bench_dict.lazy_entry):
return ret.get()
return ret
def __len__(self):
return len(self.data)
def __setitem__(self, key, value):
self.data[id(key)] = value
bench_registry = bench_dict()
class kernel:
def __init__(self, src, outputs):
@@ -200,7 +233,7 @@ class kernel:
defines.append((k, values))
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [4]
opt.num_warps = [2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
@@ -209,6 +242,10 @@ class kernel:
if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, opt)
# benchmarking info
bench = 0
if 'bench' in kwargs:
bench = kwargs['bench']
# retrieve framework op
op_id = self.fw_id[key]
# register grid
@@ -217,9 +254,16 @@ class kernel:
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
# call framework function
if fw.has_tensorflow():
return self.fw_op(*op_args, id=op_id)
bench_id = libtriton.make_scalar_id() if bench > 0 else 0
ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
if bench > 0:
bench_registry[ret] = bench_dict.lazy_entry(bench_id)
elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
return self.fw_op(op_id, *args)
ret = self.fw_op(op_id, bench, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
else:
assert False
assert False
return ret

View File

@@ -11,38 +11,36 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + 0 ... TM;
int ryb[TN] = ridy * TN + 0 ... TN;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// epilogue
int rxc[TM] = ridx * TM + 0 ... TM;
int ryc[TN] = ridy * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
*?(checkc) pc = c;
TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :];
*pc = c;
}
"""
kernel = triton.kernel(src, ['C'])
@staticmethod
def _call(a, b, transpose_a, transpose_b):
def _call(a, b, transpose_a, transpose_b, bench = 0):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
@@ -78,16 +76,17 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
grid, bench=bench,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [128], TN = [128], TK = [8], **macros)
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
return _dot._call(a, b, transpose_a, transpose_b)
return _dot._call(a, b, transpose_a, transpose_b, bench)
@staticmethod
def backward(ctx, dy):
@@ -108,5 +107,5 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
else:
assert False
return da, db, None, None, None, None, None, None, None
dot = _dot.apply

View File

@@ -2,52 +2,58 @@
import triton
import math
class _einsum(triton.function):
src = """
void einsum_(TYPE * A, TYPE * B, TYPE * C,
int dim_M, int dim_N, int dim_K,
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
// program id
int pgm = get_program_id(0);
int pgn = get_program_id(1);
int pgb = get_program_id(2);
// range
int rm[TM] = pgm * TM + 0 ... TM;
int rn[TN] = pgn * TN + 0 ... TN;
int rb[TB] = pgb * TB + 0 ... TB;
int rk[TK] = 0 ... TK;
// accumulator
TYPE c[TM, TN, TB] = 0;
// pointers to a
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
+ rm[BROADCAST_AM] * STRIDE_AM
+ rb[newaxis, newaxis, :] * std_A0;
// pointers to b
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
+ rn[BROADCAST_BN] * STRIDE_BN
+ rb[newaxis, newaxis, :] * std_B0;
// accumulation
for(int k = dim_K; k > 0; k -= TK) {
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
c += USE_A @ USE_B;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
// write-back
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
+ rn[newaxis, :, newaxis] * 1
+ rb[newaxis, newaxis, :] * std_C0;
bool checkm[TM] = rm < dim_M;
bool checkn[TN] = rn < dim_N;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
*?(checkc)pc = c;
void einsum_(TYPE * A, TYPE * B, TYPE * C,
int dim_M, int dim_N, int dim_K,
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
// program id
int pgm = get_program_id(0);
int pgn = get_program_id(1);
int pgb = get_program_id(2);
// range
int rm[TM] = pgm * TM + 0 ... TM;
int rn[TN] = pgn * TN + 0 ... TN;
int rb[TB] = pgb * TB + 0 ... TB;
int rk[TK] = 0 ... TK;
// accumulator
TYPE c[TM, TN, TB] = 0;
// pointers to a
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
+ rm[BROADCAST_AM] * STRIDE_AM
+ rb[newaxis, newaxis, :] * std_A0;
// pointers to b
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
+ rn[BROADCAST_BN] * STRIDE_BN
+ rb[newaxis, newaxis, :] * std_B0;
// prefetch
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// accumulation
for(int k = dim_K; k > 0; k -= TK) {
c += USE_A @ USE_B;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
"""
// write-back
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
+ rn[newaxis, :, newaxis] * 1
+ rb[newaxis, newaxis, :] * std_C0;
bool checkm[TM] = rm < dim_M;
bool checkn[TN] = rn < dim_N;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
*?(checkc)pc = c;
}
"""
kernel = triton.kernel(src, ['C'])
@@ -134,7 +140,8 @@ class _einsum(triton.function):
@staticmethod
def call(a, b, trans_a, trans_b, shape_c, bmnk,
std0, std1, einsum_a, einsum_b, einsum_c):
std0, std1, einsum_a, einsum_b, einsum_c,
bench):
dtype = a.dtype
c = triton.empty(shape_c, dtype)
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
@@ -154,16 +161,22 @@ class _einsum(triton.function):
'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
return _einsum.kernel(a, b, c,
TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
print(TM)
print(TN)
return _einsum.kernel(a, b, c,
bmnk[1], bmnk[2], bmnk[3],
std0[0], std0[1], std0[2],
std1[0], std1[1], std1[2],
grid, **macros,
TYPE='float', TM=32, TN=32, TK=8, TB=1)
grid, bench=bench,
**macros,
TYPE='float', TM=TM, TN=TN, TK=8, TB=TB)
@staticmethod
def forward(ctx, subscripts, a, b):
def forward(ctx, subscripts, a, b, **kwargs):
ctx.save_for_backward(a, b)
if type(subscripts) is str:
einsum_a, einsum_bc = subscripts.split(",")
@@ -173,14 +186,16 @@ class _einsum(triton.function):
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
einsum_a, einsum_b, einsum_c,
a.shape.as_list(), b.shape.as_list()
triton.shape(a), triton.shape(b)
)
bench = kwargs['bench'] if 'bench' in kwargs else 0
ctx.trans_a = ta
ctx.trans_b = tb
ctx.einsum_a = einsum_a
ctx.einsum_b = einsum_b
ctx.einsum_c = einsum_c
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
ctx.bench = bench
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
@staticmethod
@@ -191,22 +206,23 @@ class _einsum(triton.function):
einsum_a = ctx.einsum_a
einsum_b = ctx.einsum_b
einsum_c = ctx.einsum_c
bench = ctx.bench
if not trans_a and not trans_b: # NN
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
elif not trans_a and trans_b: # NT
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
elif trans_a and not trans_b: # TN
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
elif trans_a and trans_b: # TT (not used)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
return da, db, None, None, None, None, None, None, None, None, None, None

View File

@@ -22,7 +22,8 @@ class lazy_shape:
def shape(A) :
if fw.has_tensorflow():
return lazy_shape(fw.tensorflow.shape(A))
return A.shape.as_list()
#return lazy_shape(fw.tensorflow.shape(A))
elif fw.has_torch():
return A.shape
else:

View File

@@ -34,7 +34,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}