[PYTHON][EXAMPLES] Better einsum example
@@ -8,6 +8,7 @@
 #include <string>
 #include <memory>
 #include <functional>
+#include <mutex>
 // codegen
 #include "triton/ir/context.h"
 #include "triton/codegen/target.h"
@@ -110,6 +111,7 @@ private:
   std::string src_;
   options_space_t opt_space_;
   std::map<cache_key_t, caller> cache_;
+  std::mutex src_mutex_;
 };
 
 }
@@ -9,8 +9,6 @@
 #include <unordered_map>
 
 
-extern std::string filename_in;
-extern std::string filename_out;
 
 using DirectiveMap = std::unordered_map<std::string, int>;
 
@@ -29,9 +29,9 @@
 #include "triton/ir/print.h"
 #include "triton/tools/bench.hpp"
 #include "llvm/IR/Module.h"
-
+#include <mutex>
 
-
+std::mutex mut;
 
 namespace triton{
 namespace runtime {
@@ -168,7 +168,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
   for(auto it: opt_space_.defines)
     cpp.AddMacro(it.first, &opt.defines.at(it.first));
   cpp.Process(tokens);
-  // tokens.Print(stdout);
   // parse
   Parser parser(tokens);
   parser.Parse();
@@ -309,7 +308,10 @@ void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_f
   }
 
   /* re-tune and re-compile */
-  cache_.insert({key, autotune(stream, grid_fn, args)});
+  {
+    std::lock_guard<std::mutex> lock(mut);
+    cache_.insert({key, autotune(stream, grid_fn, args)});
+  }
 }
 
 void function::operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream *stream) {
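Aside (not part of the commit): the hunk above serializes re-compilation and cache insertion behind the global std::mutex `mut`, so concurrent calls to function::operator() cannot race on cache_. A minimal sketch of the same lock-around-a-cache pattern, written here in Python with an illustrative threading.Lock and helper names that are not from the repository; the re-check inside the lock is a common refinement, not something the hunk itself does:

import threading

_cache = {}
_cache_lock = threading.Lock()   # plays the role of `mut` above

def get_or_compile(key, compile_fn):
    # Fast path: the kernel was already compiled and cached.
    if key in _cache:
        return _cache[key]
    # Slow path: take the lock so only one thread compiles and inserts.
    with _cache_lock:
        if key not in _cache:        # re-check after acquiring the lock
            _cache[key] = compile_fn(key)
        return _cache[key]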
@@ -1,38 +1,92 @@
-import numpy as np
-import torch
-import triton
-
-batch_dim = 16
-ctx_dim = 32
-head_dim = 8
-state_dim = 32
-key_dim = 32
-n_keys = 32
-bs = batch_dim * ctx_dim
-
-# shapes
-x_shape = (bs, state_dim)
-qw_shape = (state_dim, head_dim * key_dim)
-kw_shape = (head_dim, 2, n_keys, key_dim // 2)
-
-np.random.seed(0)
-x = np.random.uniform(-1.0, 1.0, x_shape).astype(np.float32) # layer input
-qw = np.random.uniform(-1.0, 1.0, qw_shape).astype(np.float32) # query weights
-kw = np.random.uniform(-1.0, 1.0, kw_shape).astype(np.float32) # key weights
-# (bs, head_dim * key_dim) = (bs, state_dim) * (state_dim, head_dim * key_dim)
-# (bs, head_dim, 2, key_dim//2) <== (bs, head_dim * key_dim)
-q = np.dot(x, qw).reshape(bs, head_dim, 2, key_dim//2) # normal matmul
-
-# (bs, head_dim, 2, n_keys) = (bs, head_dim, 2, key_dim//2) * (head_dim, 2, n_keys, key_dim//2)
-# outer: bs, n_keys
-# inner: key_dim//2
-# batch: head_dim, 2 (key_axis)
-qk = np.einsum("bhak,hank->bhan", q, kw)
-
-tq = torch.from_numpy(q).contiguous().cuda()
-tkw = torch.from_numpy(kw).contiguous().cuda()
-tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
-diff = np.abs(qk - tqk.cpu().numpy())
-print(np.max(diff))
-print(np.min(diff))
-
+#!/usr/bin/env python
+
+import numpy as np
+from enum import Enum
+import triton
+
+class MODE(Enum):
+    TF = 1
+    TORCH = 2
+
+try:
+    import tensorflow as tf
+    mode = MODE.TF
+except ModuleNotFoundError:
+    pass
+
+try:
+    import torch
+    mode = MODE.TORCH
+except ModuleNotFoundError:
+    pass
+
+cases = []
+# Matmul
+cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
+# Attention
+cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
+
+if mode == MODE.TF:
+    sess = tf.InteractiveSession()
+
+for a_shape, b_shape, c_shape, einsum in cases:
+
+    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
+    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
+    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
+
+    # Execute (tensorflow)
+    if mode == MODE.TF:
+        a = tf.placeholder(tf.float32, a_shape, name="a")
+        b = tf.placeholder(tf.float32, b_shape, name="b")
+        e = tf.placeholder(tf.float32, c_shape, name="e")
+        c = triton.ops.einsum(einsum, a, b, 1)
+        da, db = tf.gradients(c, [a, b], e)
+        feed_dict = { a: A.astype(np.float32),
+                      b: B.astype(np.float32),
+                      e: E }
+        sess.run(tf.global_variables_initializer())
+        result = sess.run([c, da, db], feed_dict = feed_dict)
+    # Execute (torch)
+    if mode == MODE.TORCH:
+        a = torch.from_numpy(A).cuda()
+        b = torch.from_numpy(B).cuda()
+        e = torch.from_numpy(E).cuda()
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        c = triton.ops.einsum(einsum, a, b, 1)
+        torch.autograd.backward(c, e)
+        da = a.grad
+        db = b.grad
+        result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]
+
+    # benchmark
+    nanosec = triton.bench_registry[c]
+    ctx = triton.ctx_registry[c]
+    b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
+    ops = 2.*b*m*n*k
+    print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
+    #print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
+    #print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
+
+    # test
+    ctx = triton.ctx_registry[c]
+    t_a = ctx.trans_a
+    t_b = ctx.trans_b
+    e_a = ctx.einsum_a
+    e_b = ctx.einsum_b
+    e_c = ctx.einsum_c
+    C = np.einsum(einsum, A, B)
+    if not t_a and not t_b: # NN
+        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+    elif not t_a and t_b: # NT
+        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
+    elif t_a and not t_b: # TN
+        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
+        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+    c, da, db = result[0], result[1], result[2]
+    print('C diff:', np.abs((C - c)).max())
+    print('DA diff:', np.abs((DA - da)).max())
+    print('DB diff:', np.abs((DB - db)).max())
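Aside (not part of the commit): the test block in the new example validates Triton's gradients against closed-form einsum identities, e.g. for the NN case DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B) and DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E). A minimal sketch, independent of Triton and using only NumPy plus PyTorch on the CPU, of how those identities can be checked against autograd for the "btc,ck->btk" case (the small shapes below are illustrative):

import numpy as np
import torch

A = np.random.uniform(-1.0, 1.0, (2, 3, 4)).astype(np.float32)   # "btc"
B = np.random.uniform(-1.0, 1.0, (4, 5)).astype(np.float32)      # "ck"
E = np.random.uniform(-1.0, 1.0, (2, 3, 5)).astype(np.float32)   # upstream gradient, "btk"

# Reference gradients from autograd.
ta = torch.tensor(A, requires_grad=True)
tb = torch.tensor(B, requires_grad=True)
tc = torch.einsum("btc,ck->btk", ta, tb)
tc.backward(torch.tensor(E))

# Closed-form gradients: substitute the output subscripts into the slot of the
# input being differentiated.
DA = np.einsum("btk,ck->btc", E, B)   # f"{e_c},{e_b}->{e_a}"
DB = np.einsum("btc,btk->ck", A, E)   # f"{e_a},{e_c}->{e_b}"

print(np.abs(DA - ta.grad.numpy()).max())   # ~1e-6
print(np.abs(DB - tb.grad.numpy()).max())   # ~1e-6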
@@ -1,139 +0,0 @@
-#!/usr/bin/env python
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-import triton
-import blocksparse as bs
-from tensorflow.python.ops import gradient_checker
-
-one = 0
-out = 0
-bench = 0
-
-class ProdKeyTest(tf.test.TestCase):
-
-    def testEinsum(self):
-        # multi-threading screws up benchmarking
-        conf = tf.ConfigProto(
-            intra_op_parallelism_threads=1,
-            inter_op_parallelism_threads=1)
-
-        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):
-
-            batch_dim = 4
-            ctx_dim = 256
-            head_dim = 8
-            n_keys = 512
-            key_dim = 128
-
-            # batch_dim = 2
-            # ctx_dim = 8
-            # head_dim = 2
-            # n_keys = 16
-            # key_dim = 16
-
-            for a_shape, b_shape, c_shape, einsum in [
-                [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-                [ [4, 1024, 1024], [ 1024, 1024 ], [4, 1024, 1024 ], "btc,ck->btk" ],
-                [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
-            ]:
-
-                if one:
-                    A = np.ones(a_shape, dtype=np.float16).astype(np.float32)
-                    B = np.ones(b_shape, dtype=np.float16).astype(np.float32)
-                    E = np.ones(c_shape, dtype=np.float32)
-                else:
-                    # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
-                    # V = np.random.normal(loc=0.0, scale=1.0, size=vw_shape).astype(np.float16).astype(np.float32)
-                    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
-                    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
-                    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
-
-                a = tf.placeholder(tf.float32, a_shape, name="a")
-                b = tf.placeholder(tf.float32, b_shape, name="b")
-                e = tf.placeholder(tf.float32, c_shape, name="e")
-                feed_dict = { a: A.astype(np.float32),
-                              b: B.astype(np.float32),
-                              e: E }
-
-                c = triton.ops.einsum(einsum, a, b, bench=bench)
-
-                # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
-                # print(error)
-                # error = gradient_checker.compute_gradient_error(b, b_shape, c, c_shape, delta=1e-1, extra_feed_dict={ a:A }) #
-                # print(error)
-                # return
-
-                with tf.control_dependencies([c.op]):
-                    da, db = tf.gradients(c, [a, b], e)
-
-                # c, = sess.run( [ c, ], feed_dict )
-                rc, rda, rdb = sess.run( [ c, da, db ], feed_dict )
-
-                if bench > 0:
-                    nanosec = triton.bench_registry[c]
-                    ctx = triton.ctx_registry[c]
-                    b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
-                    ops = 2. * b * m * n * k
-                    print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
-                    print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
-                    print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
-
-                else:
-                    C = np.einsum(einsum, A, B)
-                    ctx = triton.ctx_registry[c]
-                    t_a = ctx.trans_a
-                    t_b = ctx.trans_b
-                    e_a = ctx.einsum_a
-                    e_b = ctx.einsum_b
-                    e_c = ctx.einsum_c
-
-                    if not t_a and not t_b: # NN
-                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
-                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
-                    elif not t_a and t_b: # NT
-                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
-                        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
-                    elif t_a and not t_b: # TN
-                        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
-                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
-
-                    print("testProdKey", einsum)
-                    if not bench:
-                        for op, dev, cpu in [
-                            [ "C",  rc,  C  ],
-                            [ "DA", rda, DA ],
-                            [ "DB", rdb, DB ],
-                        ]:
-                            self.compare_results(op, dev, cpu)
-
-    def compare_results(self, op, dev, cpu):
-        dev = dev.astype(np.float64)
-        cpu = cpu.astype(np.float64)
-
-        # print(dev.reshape(-1)[0:4])
-        # print(cpu.reshape(-1)[0:4])
-
-        dif = np.abs(cpu - dev)
-        maxval = np.max(abs(cpu))
-        avgval = np.average(abs(cpu))
-        maxdif = dif.max()
-        max_err = maxdif if avgval == 0 else maxdif / avgval
-        l2_err = 0.0 if avgval == 0 else np.sqrt(np.square(dif).sum()) / np.sqrt(np.square(cpu).sum())
-
-        print("op:%3s, max:%18.12f, avg:%18.12f, dif:%18.12f, err:%18.12f, l2_err:%18.12f shape:%15s" % (op, maxval, avgval, maxdif, max_err, l2_err, str(cpu.shape)))
-
-        if out:
-            dim = cpu.shape[-1]
-            np.savetxt("%s_dif.txt" % op, dif.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
-            np.savetxt("%s_cpu.txt" % op, cpu.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
-            np.savetxt("%s_dev.txt" % op, dev.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
-            exit()
-
-if __name__ == "__main__":
-    tf.test.main()
@@ -179,7 +179,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
 
 
     @staticmethod
-    def forward(ctx, subscripts, a, b, **kwargs):
+    def forward(ctx, subscripts, a, b, bench = 0):
        ctx.save_for_backward(a, b)
        if type(subscripts) is str:
          einsum_a, einsum_bc = subscripts.split(",")
@@ -189,9 +189,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
 
        shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
                           einsum_a, einsum_b, einsum_c,
-                          triton.shape(a), triton.shape(b)
-       )
-       bench = kwargs['bench'] if 'bench' in kwargs else 0
+                          triton.shape(a), triton.shape(b))
        ctx.trans_a = ta
        ctx.trans_b = tb
        ctx.einsum_a = einsum_a
@@ -213,20 +211,20 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
        bench = ctx.bench
 
        if not trans_a and not trans_b: # NN
-         da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
-         db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
+         da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
+         db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
 
        elif not trans_a and trans_b: # NT
-         da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
-         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
+         da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
+         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
 
        elif trans_a and not trans_b: # TN
-         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
-         db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
+         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
+         db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
 
        elif trans_a and trans_b: # TT (not used)
-         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
-         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
+         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
+         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
 
        return da, db, None, None, None, None, None, None, None, None, None, None
 