[GENERAL] Merged einsum feature branch. Various features, performance
improvements, and bugfixes:
* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvements for FP32 kernels containing matrix multiplication
* Added a re-coalescing pass for FP16 kernels containing matrix multiplication
* Various bugfixes
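For context, a minimal sketch of how the new einsum op is invoked, following the usage in the updated example script below; the shapes here are hypothetical and the API is still preliminary:

import torch
import triton

a = torch.randn(64, 256).type(torch.cuda.HalfTensor)
b = torch.randn(256, 128).type(torch.cuda.HalfTensor)
# plain matrix multiplication written as an einsum; the last
# argument is the output shape
c = triton.ops.einsum('mk,kn->mn', a, b, [64, 128])
# extended subscripts allow index arithmetic, e.g. a dense 2D
# convolution: 'nc(h+r)(w+s),crsk->nkhw'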
@@ -1,157 +0,0 @@
import tensorflow as tf
import triton
import numpy as np

src = '''
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif

#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BN :, newaxis
#define BROADCAST_BK newaxis, :
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BN newaxis, :
#define BROADCAST_BK :, newaxis
#define SHAPE_B TK, TN
#endif

void dot(TYPE* A __readonly __noalias __align(16),
         TYPE* B __readonly __noalias __align(16),
         TYPE* C __writeonly __noalias __align(16),
         int lda, int ldb, int ldc,
         int N, int* lut,
         int* locks, int nlocks) {
  int ridx = get_program_id(0);
  float c[TM, TN] = 0;
  int rka[TK] = 0 ... TK;
  int rkb[TK] = 0 ... TK;
  // load LUT header
  int *header = lut + get_program_id(1) * 4;
  int offset = *(header + 0);
  int K = *(header + 1);
  int column = *(header + 2);
  int lockid = *(header + 3);
  int *plut = lut + offset * 2;
  int offx = ridx;
  int offy = 0;
  // compute x, y offsets
  int rxa[TM] = offx * TM + (0 ... TM);
  int ryb[TN] = offy * TN + (0 ... TN);
  // bounds checking
  bool checka[SHAPE_A] = (rxa < N)[BROADCAST_AM];
  bool checkb[SHAPE_B] = 1;
  // base offsets
  int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
  int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
  for(int k = K; k > 0; k -= 1) {
    // fetch block indices from the LUT
    int ak = *(plut + 0);
    int bk = *(plut + 1);
    plut += 2;
    // compute pointers to blocks
    TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
    TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
    // load blocks
    TYPE a[SHAPE_A] = checka ? *pa : 0;
    TYPE b[SHAPE_B] = *pb;
    // multiply blocks
    c += USE_A @ USE_B;
  }
  int rxc[TM] = ridx * TM + (0 ... TM);
  int ryc[TN] = column * TN + (0 ... TN);
  TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc;
  bool checkc[TM, TN] = (rxc < N)[:, newaxis];
  if(lockid == 0) {
    *?(checkc) pc = c;
  }
  else {
    int *plock = locks + ridx * nlocks + lockid - 1;
    int *pcount = plock + get_num_program(0) * nlocks;
    // spin until the lock is acquired
    while(atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc) pc = c;
    else
      *?(checkc) pc = c + *pc;
    atomic_exch(pcount, 1);
    atomic_exch(plock, 0);
  }
}
'''

# std::string dot::triton_c_src_dw() const {
#   bool AT = (op_ == WGRAD);
#   bool BT = (op_ == FPROP);
#   std::string usea = AT ? "trans(a)" : "a";
#   std::string useb = BT ? "trans(b)" : "b";
#   std::string sizea = AT ? "TK, TM" : "TM, TK";
#   std::string sizeb = BT ? "TN, TK" : "TK, TN";
#   std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
#   std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
#   std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
#   std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
#   std::string lda0 = AT ? "*lda" : "";
#   std::string lda1 = AT ? "" : "*lda";
#   std::string ldb0 = BT ? "" : "*ldb";
#   std::string ldb1 = BT ? "*ldb" : "";
#   std::string result =
#   R"(
#   const tunable int TM = {)" + std::to_string(BS_) + R"(};
#   const tunable int TN = {)" + std::to_string(BS_) + R"(};
#   const tunable int TK = {32};
#   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
#              restrict read_only align(16) )" + ab_ty_ + R"( *B,
#              )" + c_ty_ + R"(* C,
#              int lda, int ldb, int ldc,
#              int N, int* lut,
#              int* locks, int nlocks) {
#     int ridx = get_range_id(0);
#     float acc[TM, TN] = 0;
#     int rka[TK] = 0 ... TK;
#     int rkb[TK] = 0 ... TK;
#     int *header = lut + ridx * 2;
#     int offx = *(header + 0);
#     int offy = *(header + 1);
#     int rxa[TM] = offx*TM + (0 ... TM);
#     int ryb[TN] = offy*TN + (0 ... TN);
#     bool checka[TK, TM] = (rka < N)[:, newaxis];
#     bool checkb[TK, TN] = (rkb < N)[:, newaxis];
#     int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
#     int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
#     )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
#     )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
#     )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
#     )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
#     for(int k = N; k > 0; k = k - TK) {
#       acc = dot()" + usea + ", " + useb + R"(, acc);
#       pa = pa + TK)" + lda1 + R"(;
#       pb = pb + TK)" + ldb1 + R"(;
#       a = checka ? *pa : 0;
#       b = checkb ? *pb : 0;
#     }
#     int rxc[TM] = (0 ... TM);
#     int ryc[TN] = (0 ... TN);
#     )" + c_ty_ + R"( c[TM, TN] = acc;
#     )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
#     *pc = c;
#   })";
@@ -1,15 +0,0 @@
import torch
import triton

N, C, K = 32, 8, 32
H, W = 16, 16
R, S = 3, 3
torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda()
b = torch.ones(C, R, S, K).cuda()

rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b)
print((rc - tc).abs().max())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:])
@@ -1,71 +0,0 @@
import numpy as np
import triton


def run_tf():
    M, N, K = 2048, 2048, 2048
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = tf.math.reduce_mean(triton_d)
    fw_c = tf.matmul(a, b, False, True)
    fw_d = tf.matmul(fw_c, b, True, False)
    fw_y = tf.math.reduce_mean(fw_d)
    # Gradient
    triton_da, triton_db = tf.gradients(triton_y, [a, b])
    fw_da, fw_db = tf.gradients(fw_y, [a, b])
    # Reference
    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
                 b: np.random.rand(K, N).astype(np.float32)}
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict=feed_dict)
    triton_da, fw_da = result[0][0], result[1][0]
    triton_db, fw_db = result[2][0], result[3][0]
    # Benchmark
    nanosec = triton.bench_registry[triton_d]
    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    print('Diff DA:', (triton_da - fw_da).max())
    print('Diff DB:', (triton_db - fw_db).max())


def run_torch():
    torch.manual_seed(0)
    M, N, K = 2048, 2048, 2048
    a = torch.randn(M, K).cuda()
    b = torch.randn(K, N).cuda()
    a.requires_grad_(True)
    b.requires_grad_(True)
    torch_c = torch.matmul(a, torch.t(b))
    torch_d = torch.matmul(torch.t(torch_c), b)
    torch_y = torch.mean(torch_d)
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = torch.mean(triton_d)
    # torch gradient
    torch_y.backward()
    torch_da = a.grad.clone()
    torch_db = b.grad.clone()
    # triton gradient
    a.grad.zero_()
    b.grad.zero_()
    triton_y.backward()
    triton_da = a.grad.clone()
    triton_db = b.grad.clone()

    #nanosec = triton.bench_registry[triton_d]
    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    print('Diff DA:', (torch_da - triton_da).max())
    print('Diff DB:', (torch_db - triton_db).max())


try:
    import tensorflow as tf
    run_tf()
except ModuleNotFoundError:
    pass

try:
    import torch
    run_torch()
except ModuleNotFoundError:
    pass
@@ -1,92 +1,194 @@
 #!/usr/bin/env python

-import numpy as np
-from enum import Enum
 import triton
+import torch
+from torch.utils.cpp_extension import load
+import numpy as np
+#import utils
+from time import time

-class MODE(Enum):
-    TF = 1
-    TORCH = 2
+#torch.backends.cudnn.benchmark = True

-try:
-    import tensorflow as tf
-    mode = MODE.TF
-except ModuleNotFoundError:
-    pass
+configs = []

-try:
-    import torch
-    mode = MODE.TORCH
-except ModuleNotFoundError:
-    pass
+# Matrix multiplication
+MNK = [
+       (512, 512, 512),
+       (2048, 2048, 2048),
+       (8192, 8192, 8192),
+
+       (64, 64, 64000),
+       (64, 64, 128000),
+       (256, 256, 64000),
+       (256, 256, 128000),
+
-cases = []
-# Matmul
-cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
-# Attention
-# cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
+       (1536, 16, 1536),
+       (1536, 32, 1536),
+       (1536, 64, 1536),
+       (1536, 128, 1536),
+       (4096, 16, 4096),
+       (4096, 32, 4096),
+       (4096, 64, 4096),
+       (4096, 128, 4096),
+
+       #(127008, 768, 576)
+      ]
+for M, N, K in MNK:
+    matmul = lambda a, b: torch.matmul(a, b)
+    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+for M, N, K in MNK:
+    matmul = lambda a, b: torch.matmul(a.t(), b)
+    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
+for M, N, K in MNK:
+    matmul = lambda a, b: torch.matmul(a, b.t())
+    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]

-if mode == MODE.TF:
-    sess = tf.InteractiveSession()
+# Relative attention
+NTHSE = [
+         #(16, 512, 1, 64, 64),
+         #(16, 512, 1, 128, 128),
+         #(16, 512, 1, 256, 256),
+         #(16, 512, 1, 256, 512),
+         #(16, 512, 8, 64, 64),
+         #(16, 512, 8, 128, 128),
+         #(16, 512, 8, 256, 256),
+         #(16, 512, 8, 256, 512),
-for a_shape, b_shape, c_shape, einsum in cases:
+         #(64, 1024, 1, 64, 64),
+         #(64, 1024, 1, 128, 128),
+         #(64, 1024, 1, 256, 256),
+         #(64, 1024, 1, 256, 512),
+         #(64, 1024, 8, 64, 64),
+         #(64, 1024, 8, 128, 128),
+         #(64, 1024, 8, 256, 256),
+         #(64, 1024, 8, 256, 512),
-    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
-    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
-    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
+         #(128, 1024, 1, 64, 64),
+         #(128, 1024, 1, 128, 128),
+         #(128, 1024, 1, 256, 256),
+         #(128, 1024, 1, 256, 512),
+         #(128, 1024, 8, 64, 64),
+         #(128, 1024, 8, 128, 128),
+         #(128, 1024, 8, 256, 256),
+         #(128, 1024, 8, 256, 512)
+        ]
+for N, T, H, S, E in NTHSE:
+    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
+for N, T, H, S, E in NTHSE:
+    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
+for N, T, H, S, E in NTHSE:
+    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]

-    # Execute (tensorflow)
-    if mode == MODE.TF:
-        a = tf.placeholder(tf.float32, a_shape, name="a")
-        b = tf.placeholder(tf.float32, b_shape, name="b")
-        e = tf.placeholder(tf.float32, c_shape, name="e")
-        c = triton.ops.einsum(einsum, a, b, 1)
-        da, db = tf.gradients(c, [a, b], e)
-        feed_dict = {a: A.astype(np.float32),
-                     b: B.astype(np.float32),
-                     e: E}
-        sess.run(tf.global_variables_initializer())
-        result = sess.run([c, da, db], feed_dict=feed_dict)
-    # Execute (torch)
-    if mode == MODE.TORCH:
-        a = torch.from_numpy(A).cuda()
-        b = torch.from_numpy(B).cuda()
-        e = torch.from_numpy(E).cuda()
-        a.requires_grad_(True)
-        b.requires_grad_(True)
-        c = triton.ops.einsum(einsum, a, b, 1)
-        torch.autograd.backward(c, e)
-        da = a.grad
-        db = b.grad
-        result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]

-    # benchmark
-    nanosec = triton.bench_registry[c]
-    ctx = triton.ctx_registry[c]
-    b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
-    ops = 2.*b*m*n*k
-    print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
-    #print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
-    #print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
+# 1D Dense convolution
+NCHKR = [
+         #(1, 1152, 12602, 512, 3)
+        ]
+for N, C, H, K, R in NCHKR:
+    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
+    configs += [([N, C, H],
+                 [C, R, K],
+                 [N, K, H - R + 1],
+                 torch_fn,
+                 'nc(h+r),crk->nkh',
+                 dict())]

-    # test
-    ctx = triton.ctx_registry[c]
-    t_a = ctx.trans_a
-    t_b = ctx.trans_b
-    e_a = ctx.einsum_a
-    e_b = ctx.einsum_b
-    e_c = ctx.einsum_c
-    C = np.einsum(einsum, A, B)
-    if not t_a and not t_b: # NN
-        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
-        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
-    elif not t_a and t_b: # NT
-        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
-        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
-    elif t_a and not t_b: # TN
-        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
-        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
-    c, da, db = result[0], result[1], result[2]
-    print('C diff:', np.abs(C - c).max())
-    print('DA diff:', np.abs(DA - da).max())
-    print('DB diff:', np.abs(DB - db).max())
+# 2D Dense convolution
+NCHWKRS = [
+           #(8, 64, 128, 128, 768, 3, 3),
+           #(8, 128, 64, 64, 256, 3, 3),
+           #(8, 256, 32, 32, 512, 3, 3),
+           #(8, 512, 32, 32, 1024, 3, 3)
+          ]
+for N, C, H, W, K, R, S in NCHWKRS:
+    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
+    configs += [([N, C, H, W],
+                 [C, R, S, K],
+                 [N, K, H - R + 1, W - R + 1],
+                 torch_fn,
+                 'nc(h+r)(w+s),crsk->nkhw',
+                 dict())]
+
+# 3D Dense convolution
+NCDHWKTRS = [
+             #(8, 32, 27, 100, 100, 64, 3, 3, 3),
+             #(8, 64, 23, 48, 48, 256, 3, 3, 3),
+             #(8, 256, 19, 22, 22, 640, 3, 3, 3),
+             #(8, 640, 15, 36, 36, 384, 3, 3, 3)
+            ]
+for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
+    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
+    configs += [([N, C, D, H, W],
+                 [C, T, R, S, K],
+                 [N, K, D - T + 1, H - R + 1, W - R + 1],
+                 torch_fn,
+                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
+                 dict())]
+
+# Shift convolution
+shift_cuda = torch.utils.cpp_extension.load(
+    'shift_cuda', ['kernels/shift_cuda.cpp',
+                   'kernels/shift_cuda_kernel.cu'],
+    extra_cflags=['-O3'])
+
+class shift(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, shift):
+        ctx.save_for_backward(shift)
+        return shift_cuda.forward(x, shift)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        shift, = ctx.saved_tensors
+        grad_output = shift_cuda.backward(grad_output, shift)
+        return grad_output, None
+
+NCHWKRS = [
+           #(8, 64, 128, 128, 128, 3, 3),
+           #(8, 128, 64, 64, 256, 3, 3),
+           #(8, 256, 32, 32, 512, 3, 3),
+           #(8, 512, 32, 32, 1024, 3, 3)
+          ]
+for N, C, H, W, K, R, S in NCHWKRS:
+    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
+    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
+    def shift_conv(a, b, **kwargs):
+        shift_h, shift_w = kwargs['sh'], kwargs['sw']
+        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
+        shift_torch = torch.from_numpy(shift_torch).cuda()
+        a = shift.apply(a, shift_torch)
+        b = b.permute(1, 0)
+        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
+        return torch.nn.functional.conv2d(a, b)
+    configs += [([N, C, H, W],
+                 [C, K],
+                 [N, K, H, W],
+                 shift_conv,
+                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
+                 {'sh': shift_h, 'sw': shift_w})]
+
+# Benchmark
+torch.set_num_threads(1)
+for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
+    dtype = torch.cuda.HalfTensor
+    # initialize input tensors
+    a = torch.rand(*a_shape).type(dtype).cuda()
+    b = torch.rand(*b_shape).type(dtype).cuda()
+    # triton output
+    #ta = triton.ops._einsum.pad(a, [4,4,4,4])
+    tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
+    # reference output
+    if torch_fn:
+        rc = torch_fn(a, b, **arrays)
+    else:
+        rc = torch.einsum(expr, a, b)
+    # performance relative to equivalent matrix multiplication
+    ctx = triton.ctx_registry[tc]
+    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
+    # a = torch.rand(B, M, K).type(dtype).cuda()
+    # b = torch.rand(B, K, N).type(dtype).cuda()
+    # tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
+    # ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
+    ratio = 0
+    # test and benchmark
+    bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
+    diff = (tc - rc).abs().max() / rc.abs().max()
+    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')
python/examples/kernels/shift_cuda.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift);

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor shift_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  CHECK_INPUT(input);
  CHECK_INPUT(shift);
  return shift_cuda_forward(input, shift);
}

at::Tensor shift_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  CHECK_INPUT(grad_input);
  CHECK_INPUT(shift);
  return shift_cuda_backward(grad_input, shift);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &shift_forward, "Shift forward (CUDA)");
  m.def("backward", &shift_backward, "Shift backward (CUDA)");
}
python/examples/kernels/shift_cuda_kernel.cu (new file, 111 lines)
@@ -0,0 +1,111 @@
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const int32_t* __restrict__ shift,
    scalar_t* __restrict__ output,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;

  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  // decompose the flat index into NCHW coordinates
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  // each channel is shifted by a per-channel (w, h) offset
  const int32_t target_w = w + shift[2*c];
  const int32_t target_h = h + shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    output[target_idx] = input[idx];
  }
}

template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
    const scalar_t* __restrict__ grad_input,
    scalar_t* __restrict__ grad_output,
    const int32_t* __restrict__ shift,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;
  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  // the gradient flows through the inverse of the forward shift
  const int32_t target_w = w - shift[2*c];
  const int32_t target_h = h - shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    grad_output[target_idx] = grad_input[idx];
  }
}
} // namespace

at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  const auto B = input.size(0);
  const auto C = input.size(1);
  const auto H = input.size(2);
  const auto W = input.size(3);
  const auto size = B*C*H*W;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto output = at::zeros_like(input);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
    shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        shift.data<int32_t>(),
        output.data<scalar_t>(),
        B,
        C,
        H,
        W);
  }));

  return output;
}

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  const auto B = grad_input.size(0);
  const auto C = grad_input.size(1);
  const auto H = grad_input.size(2);
  const auto W = grad_input.size(3);
  const auto size = B*C*H*W;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto grad_output = at::zeros_like(grad_input);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
    shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        grad_input.data<scalar_t>(),
        grad_output.data<scalar_t>(),
        shift.data<int32_t>(),
        B,
        C,
        H,
        W);
  }));

  return grad_output;
}