History prior to this date belonged to the now deprecated ISAAC project, and was deleted to save space
This commit is contained in:
48
python/examples/attention/bench.py
Normal file
48
python/examples/attention/bench.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import reference
|
||||
import optimized
|
||||
from time import time
|
||||
|
||||
use_half = True
|
||||
def cast(x):
|
||||
if use_half:
|
||||
return x.half()
|
||||
else:
|
||||
return x
|
||||
|
||||
# GPU device
|
||||
device = torch.device("cuda:0")
|
||||
# shapes
|
||||
batch, nhead = 8, 28
|
||||
dm, dk, dv = 1024, 1024, 1024
|
||||
lq, lk, lv = 1024, 1024, 1024
|
||||
# initialize tensors
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
query = cast(torch.randn(batch, lq, dm)).cuda()
|
||||
key = cast(torch.randn(batch, lk, dm)).cuda()
|
||||
value = cast(torch.randn(batch, lv, dm)).cuda()
|
||||
# initialize layers
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device))
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device))
|
||||
# test
|
||||
routput, _ = rattn(query, key, value)
|
||||
toutput, _ = tattn(query, key, value)
|
||||
diff = torch.max(torch.abs(routput - toutput))
|
||||
assert diff < 1e-2
|
||||
# benchmark
|
||||
start = time()
|
||||
routput, _ = rattn(query, key, value)
|
||||
end = time()
|
||||
rtime = end - start
|
||||
start = time()
|
||||
toutput, _ = tattn(query, key, value)
|
||||
end = time()
|
||||
ttime = end - start
|
||||
print(f'Torch: {rtime} s')
|
||||
print(f'Triton: {ttime} s')
|
50
python/examples/attention/optimized.py
Normal file
50
python/examples/attention/optimized.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
''' Multi-Head Attention module '''
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
# linear layers
|
||||
self.w_qs = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
# initialize weights
|
||||
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
|
||||
nn.init.xavier_normal_(self.fc.weight)
|
||||
# layer normalization
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
# dimensions
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, _ = q.size()
|
||||
sz_b, len_k, _ = k.size()
|
||||
sz_b, len_v, _ = v.size()
|
||||
# linear transformations
|
||||
residual = q
|
||||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
||||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
||||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
||||
# scaled dot-product attention
|
||||
attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k])
|
||||
attn = attn / np.sqrt(d_k)
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask[None], -np.inf)
|
||||
attn = torch.softmax(attn, dim=3)
|
||||
output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v])
|
||||
output = output.view(sz_b, len_q, -1)
|
||||
output = self.fc(output)
|
||||
# epilogue
|
||||
output = self.layer_norm(output + residual)
|
||||
return output, attn
|
72
python/examples/attention/reference.py
Normal file
72
python/examples/attention/reference.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ScaledDotProductAttention(nn.Module):
|
||||
''' Scaled Dot-Product Attention '''
|
||||
|
||||
def __init__(self, temperature, attn_dropout=0.1):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
attn = torch.bmm(q, k.transpose(1, 2))
|
||||
attn = attn / self.temperature
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask, -np.inf)
|
||||
attn = self.softmax(attn)
|
||||
output = torch.bmm(attn, v)
|
||||
return output, attn
|
||||
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
''' Multi-Head Attention module '''
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
# linear layers
|
||||
self.w_qs = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
# initialize weights
|
||||
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
|
||||
nn.init.xavier_normal_(self.fc.weight)
|
||||
# normalization
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
# scaled dot-product
|
||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
|
||||
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
# dimensions
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, _ = q.size()
|
||||
sz_b, len_k, _ = k.size()
|
||||
sz_b, len_v, _ = v.size()
|
||||
# linear transformations
|
||||
residual = q
|
||||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
||||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
||||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
||||
# scaled dot-product attention
|
||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
|
||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
|
||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
|
||||
if mask:
|
||||
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
|
||||
output, attn = self.attention(q, k, v, mask=mask)
|
||||
# linear transformation
|
||||
output = output.view(n_head, sz_b, len_q, d_v)
|
||||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
|
||||
output = self.fc(output)
|
||||
# normalization
|
||||
output = self.layer_norm(output + residual)
|
||||
return output, attn
|
56
python/examples/batchnorm.py
Normal file
56
python/examples/batchnorm.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import triton
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
|
||||
class MODE(Enum):
|
||||
TF = 1
|
||||
TORCH = 2
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
mode = MODE.TF
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import torch
|
||||
mode = MODE.TORCH
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
C, H, W, B = 32, 1, 1, 128
|
||||
|
||||
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
|
||||
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
|
||||
beta = np.random.uniform(-1, 1, C).astype(np.float32)
|
||||
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
|
||||
|
||||
if mode == MODE.TORCH:
|
||||
fw_x = torch.from_numpy(x).cuda()
|
||||
fw_gamma = torch.from_numpy(gamma).cuda()
|
||||
fw_beta = torch.from_numpy(beta).cuda()
|
||||
fw_dy = torch.from_numpy(dy).cuda()
|
||||
# register gradients
|
||||
fw_x.requires_grad_(True)
|
||||
fw_gamma.requires_grad_(True)
|
||||
fw_beta.requires_grad_(True)
|
||||
# execute
|
||||
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
|
||||
fw_y.backward(fw_dy)
|
||||
|
||||
if mode == MODE.TF:
|
||||
fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
|
||||
fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
|
||||
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
|
||||
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
|
||||
# execute
|
||||
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
|
||||
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
|
||||
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
|
||||
# run
|
||||
sess = tf.InteractiveSession()
|
||||
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
|
||||
print(result)
|
197
python/examples/einsum.py
Normal file
197
python/examples/einsum.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import triton
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
import numpy as np
|
||||
#import utils
|
||||
from time import time
|
||||
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
|
||||
configs = []
|
||||
|
||||
# Matrix multiplication
|
||||
MNK = [
|
||||
(512, 512 ,512),
|
||||
(2048, 2048, 2048),
|
||||
(8192, 8192, 8192),
|
||||
|
||||
# (64, 64, 64000),
|
||||
# (64, 64, 128000),
|
||||
# (256, 256, 64000),
|
||||
# (256, 256, 128000),
|
||||
|
||||
# (1536, 16, 1536),
|
||||
# (1536, 32, 1536),
|
||||
# (1536, 64, 1536),
|
||||
# (1536, 128, 1536),
|
||||
# (4096, 16, 4096),
|
||||
# (4096, 32, 4096),
|
||||
# (4096, 64, 4096),
|
||||
# (4096, 128, 4096),
|
||||
|
||||
# (127008, 768, 576)
|
||||
]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a, b)
|
||||
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a, b.t())
|
||||
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
|
||||
|
||||
# Relative attention
|
||||
NTHSE = [
|
||||
(16, 512, 1, 64, 64),
|
||||
# (16, 512, 1, 128, 128),
|
||||
# (16, 512, 1, 256, 256),
|
||||
# (16, 512, 1, 256, 512),
|
||||
(16, 512, 8, 64, 64),
|
||||
# (16, 512, 8, 128, 128),
|
||||
# (16, 512, 8, 256, 256),
|
||||
# (16, 512, 8, 256, 512),
|
||||
|
||||
# (64, 1024, 1, 64, 64),
|
||||
(64, 1024, 1, 128, 128),
|
||||
# (64, 1024, 1, 256, 256),
|
||||
# (64, 1024, 1, 256, 512),
|
||||
# (64, 1024, 8, 64, 64),
|
||||
(64, 1024, 8, 128, 128),
|
||||
# (64, 1024, 8, 256, 256),
|
||||
# (64, 1024, 8, 256, 512),
|
||||
|
||||
# (128, 1024, 1, 64, 64),
|
||||
# (128, 1024, 1, 128, 128),
|
||||
# (128, 1024, 1, 256, 256),
|
||||
(128, 1024, 1, 256, 512),
|
||||
# (128, 1024, 8, 64, 64),
|
||||
# (128, 1024, 8, 128, 128),
|
||||
# (128, 1024, 8, 256, 256),
|
||||
(128, 1024, 8, 256, 512)
|
||||
]
|
||||
for N, T, H, S, E in NTHSE:
|
||||
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
||||
for N, T, H, S, E in NTHSE:
|
||||
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
|
||||
for N, T, H, S, E in NTHSE:
|
||||
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
|
||||
|
||||
# 1D Dense convolution
|
||||
NCHKR = [
|
||||
(1, 1152, 12602, 512, 3)
|
||||
]
|
||||
for N, C, H, K, R in NCHKR:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
||||
configs += [([N, C, H],
|
||||
[C, R, K],
|
||||
[N, K, H - R + 1],
|
||||
torch_fn,
|
||||
'nc(h+r),crk->nkh',
|
||||
dict())]
|
||||
|
||||
# 2D Dense convolution
|
||||
NCHWKRS = [
|
||||
(8, 64, 128, 128, 768, 3, 3),
|
||||
(8, 128, 64, 64, 256, 3, 3),
|
||||
(8, 256, 32, 32, 512, 3, 3),
|
||||
(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, H, W, K, R, S in NCHWKRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||
configs += [([N, C, H, W],
|
||||
[C, R, S, K],
|
||||
[N, K, H - R + 1, W - R + 1],
|
||||
torch_fn,
|
||||
'nc(h+r)(w+s),crsk->nkhw',
|
||||
dict())]
|
||||
|
||||
# 3D Dense Convolution
|
||||
NCDHWKTRS = [
|
||||
(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
||||
(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
||||
(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
||||
(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
||||
]
|
||||
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
|
||||
configs += [([N, C, D, H, W],
|
||||
[C, T, R, S, K],
|
||||
[N, K, D - T + 1, H - R + 1, W - R + 1],
|
||||
torch_fn,
|
||||
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
|
||||
dict())]
|
||||
|
||||
|
||||
# Shift convolution
|
||||
shift_cuda = torch.utils.cpp_extension.load(
|
||||
'shift_cuda', ['kernels/shift_cuda.cpp',
|
||||
'kernels/shift_cuda_kernel.cu'],
|
||||
extra_cflags=['-O3'])
|
||||
class shift(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shift):
|
||||
ctx.save_for_backward(shift)
|
||||
return shift_cuda.forward(x, shift)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
shift, = ctx.saved_tensors
|
||||
grad_output = shift_cuda.backward(grad_output, shift)
|
||||
|
||||
return grad_output, None
|
||||
|
||||
NCHWKRS = [
|
||||
#(8, 64, 128, 128, 128, 3, 3),
|
||||
#(8, 128, 64, 64, 256, 3, 3),
|
||||
#(8, 256, 32, 32, 512, 3, 3),
|
||||
#(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, H, W, K, R, S in NCHWKRS:
|
||||
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
|
||||
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
|
||||
def shift_conv(a, b, **kwargs):
|
||||
shift_h, shift_w = kwargs['sh'], kwargs['sw']
|
||||
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
|
||||
shift_torch = torch.from_numpy(shift_torch).cuda()
|
||||
a = shift.apply(a, shift_torch)
|
||||
b = b.permute(1, 0)
|
||||
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
|
||||
return torch.nn.functional.conv2d(a, b)
|
||||
configs += [([N, C, H, W],
|
||||
[C, K],
|
||||
[N, K, H, W],
|
||||
shift_conv,
|
||||
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
||||
{'sh': shift_h, 'sw': shift_w})]
|
||||
|
||||
# Benchmark
|
||||
torch.set_num_threads(1)
|
||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
dtype = torch.cuda.FloatTensor
|
||||
# initialize input tensors
|
||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||
# triton output
|
||||
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
|
||||
# reference output
|
||||
if torch_fn:
|
||||
rc = torch_fn(a, b, **arrays)
|
||||
else:
|
||||
rc = torch.einsum(expr, a, b)
|
||||
# performance relative to equivalent matrix multiplication
|
||||
ctx = triton.ctx_registry[tc]
|
||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||
cmp_eqbmm = True
|
||||
if cmp_eqbmm:
|
||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
||||
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
|
||||
cmp_str = f'({ratio:4.2f})'
|
||||
else:
|
||||
cmp_str = ''
|
||||
# test and benchmark
|
||||
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
|
||||
diff = (tc - rc).abs().max() / rc.abs().max()
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
|
42
python/examples/kernels/shift_cuda.cpp
Normal file
42
python/examples/kernels/shift_cuda.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
|
||||
at::Tensor shift_cuda_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift);
|
||||
|
||||
at::Tensor shift_cuda_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift);
|
||||
|
||||
// C++ interface
|
||||
|
||||
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
at::Tensor shift_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(shift);
|
||||
|
||||
return shift_cuda_forward(input, shift);
|
||||
}
|
||||
|
||||
at::Tensor shift_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift) {
|
||||
CHECK_INPUT(grad_input);
|
||||
CHECK_INPUT(shift);
|
||||
return shift_cuda_backward(grad_input, shift);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &shift_forward, "Shift forward (CUDA)");
|
||||
m.def("backward", &shift_backward, "Shift backward (CUDA)");
|
||||
}
|
111
python/examples/kernels/shift_cuda_kernel.cu
Normal file
111
python/examples/kernels/shift_cuda_kernel.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_forward_kernel(
|
||||
const scalar_t* __restrict__ input,
|
||||
const int32_t* __restrict__ shift,
|
||||
scalar_t* __restrict__ output,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t H,
|
||||
const int32_t W) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*H*W;
|
||||
|
||||
const int32_t CHW = C*H*W;
|
||||
const int32_t HW = H*W;
|
||||
const int32_t b = idx / CHW;
|
||||
const int32_t c = (idx - b*CHW) / HW;
|
||||
const int32_t h = (idx - b*CHW - c*HW) / W;
|
||||
const int32_t w = idx - b*CHW - c*HW - h*W;
|
||||
const int32_t target_w = w + shift[2*c];
|
||||
const int32_t target_h = h + shift[2*c + 1];
|
||||
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
output[target_idx] = input[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_backward_kernel(
|
||||
const scalar_t* __restrict__ grad_input,
|
||||
scalar_t* __restrict__ grad_output,
|
||||
const int32_t* __restrict__ shift,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t W,
|
||||
const int32_t H) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*W*H;
|
||||
const int32_t CWH = C*W*H;
|
||||
const int32_t WH = W*H;
|
||||
const int32_t b = idx / CWH;
|
||||
const int32_t c = (idx - b*CWH) / WH;
|
||||
const int32_t w = (idx - b*CWH - c*WH) / W;
|
||||
const int32_t h = idx - b*CWH - c*WH - w*H;
|
||||
const int32_t target_w = w - shift[2*c];
|
||||
const int32_t target_h = h - shift[2*c + 1];
|
||||
const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
grad_output[target_idx] = grad_input[idx];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
at::Tensor shift_cuda_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto H = input.size(2);
|
||||
const auto W = input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
|
||||
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
input.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
output.data<scalar_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
at::Tensor shift_cuda_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = grad_input.size(0);
|
||||
const auto C = grad_input.size(1);
|
||||
const auto H = grad_input.size(2);
|
||||
const auto W = grad_input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto grad_output = at::zeros_like(grad_input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
|
||||
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
grad_input.data<scalar_t>(),
|
||||
grad_output.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return grad_output;
|
||||
}
|
Reference in New Issue
Block a user