History prior to this date belonged to the now-deprecated ISAAC project and was deleted to save space

Philippe Tillet
2021-07-27 12:38:38 -07:00
commit 6d7cf35123
202 changed files with 94034 additions and 0 deletions


@@ -0,0 +1,48 @@
import torch
import numpy as np
import reference
import optimized
from time import time
use_half = True
def cast(x):
    if use_half:
        return x.half()
    else:
        return x
# GPU device
device = torch.device("cuda:0")
# shapes
batch, nhead = 8, 28
dm, dk, dv = 1024, 1024, 1024
lq, lk, lv = 1024, 1024, 1024
# initialize tensors
torch.manual_seed(0)
np.random.seed(0)
query = cast(torch.randn(batch, lq, dm)).cuda()
key = cast(torch.randn(batch, lk, dm)).cuda()
value = cast(torch.randn(batch, lv, dm)).cuda()
# initialize layers
torch.manual_seed(0)
np.random.seed(0)
rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device))
torch.manual_seed(0)
np.random.seed(0)
tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device))
# test
routput, _ = rattn(query, key, value)
toutput, _ = tattn(query, key, value)
diff = torch.max(torch.abs(routput - toutput))
assert diff < 1e-2
# benchmark (single run, no torch.cuda.synchronize(); see the timing sketch after this file)
start = time()
routput, _ = rattn(query, key, value)
end = time()
rtime = end - start
start = time()
toutput, _ = tattn(query, key, value)
end = time()
ttime = end - start
print(f'Torch: {rtime} s')
print(f'Triton: {ttime} s')
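Note that CUDA kernels launch asynchronously, so the single-shot wall-clock timings above can under-report the actual kernel time. A minimal timing sketch, assuming the same `rattn`/`tattn` modules and inputs as above (the `timed` helper and its warmup/iteration counts are hypothetical):

import torch
from time import time

def timed(fn, *args, warmup=3, iters=10):
    # warm up (compilation, autotuning, memory allocator)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time() - start) / iters

# e.g. rtime = timed(rattn, query, key, value); ttime = timed(tattn, query, key, value)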


@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn as nn
import triton
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # linear layers
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        self.fc = nn.Linear(n_head * d_v, d_model)
        # initialize weights
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        nn.init.xavier_normal_(self.fc.weight)
        # layer normalization
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, mask=None):
        # dimensions
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        # linear transformations
        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        # scaled dot-product attention
        attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k])
        attn = attn / np.sqrt(d_k)
        if mask is not None:
            attn = attn.masked_fill(mask[None], -np.inf)
        attn = torch.softmax(attn, dim=3)
        output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v])
        output = output.view(sz_b, len_q, -1)
        output = self.fc(output)
        # epilogue
        output = self.layer_norm(output + residual)
        return output, attn
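For reference, the two extended-einsum contractions above correspond to standard torch.einsum expressions (the extra output-shape argument aside). A small self-contained sketch with illustrative shapes, useful as a quick numerical sanity check:

import torch

b, l, t, h, dk, dv = 2, 4, 4, 3, 8, 8     # hypothetical small sizes
q = torch.randn(b, l, h, dk)
k = torch.randn(b, t, h, dk)
v = torch.randn(b, t, h, dv)

attn = torch.einsum('blhk,bthk->hblt', q, k) / dk ** 0.5
out = torch.einsum('hblt,bthv->blhv', torch.softmax(attn, dim=3), v)
print(out.shape)  # torch.Size([2, 4, 3, 8]) == (b, l, h, dv)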


@@ -0,0 +1,72 @@
import numpy as np
import torch
import torch.nn as nn
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        output = torch.bmm(attn, v)
        return output, attn
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # linear layers
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        self.fc = nn.Linear(n_head * d_v, d_model)
        # initialize weights
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        nn.init.xavier_normal_(self.fc.weight)
        # normalization
        self.layer_norm = nn.LayerNorm(d_model)
        # scaled dot-product
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))

    def forward(self, q, k, v, mask=None):
        # dimensions
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        # linear transformations
        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        # scaled dot-product attention
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)
        # linear transformation
        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)
        output = self.fc(output)
        # normalization
        output = self.layer_norm(output + residual)
        return output, attn


@@ -0,0 +1,56 @@
import triton
import numpy as np
from enum import Enum
class MODE(Enum):
    TF = 1
    TORCH = 2

try:
    import tensorflow as tf
    mode = MODE.TF
except ModuleNotFoundError:
    pass
try:
    import torch
    mode = MODE.TORCH
except ModuleNotFoundError:
    pass
C, H, W, B = 32, 1, 1, 128
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
if mode == MODE.TORCH:
    fw_x = torch.from_numpy(x).cuda()
    fw_gamma = torch.from_numpy(gamma).cuda()
    fw_beta = torch.from_numpy(beta).cuda()
    fw_dy = torch.from_numpy(dy).cuda()
    # register gradients
    fw_x.requires_grad_(True)
    fw_gamma.requires_grad_(True)
    fw_beta.requires_grad_(True)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_y.backward(fw_dy)

if mode == MODE.TF:
    fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
    fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
    fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
    fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
    fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
    # run
    sess = tf.InteractiveSession()
    feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
    sess.run(tf.global_variables_initializer())
    result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
    print(result)
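The tensors above use a channels-first (C, H, W, B) layout, and the TF branch takes moments over axes [1, 2, 3], i.e. each channel is normalized over its H*W*B elements. A minimal NumPy sketch of the forward pass under that assumption (biased variance and the 1e-4 eps are assumptions; the exact conventions of triton.ops.batchnorm may differ):

import numpy as np

C, H, W, B = 32, 1, 1, 128
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)

mean = x.mean(axis=(1, 2, 3), keepdims=True)   # per-channel mean
var = x.var(axis=(1, 2, 3), keepdims=True)     # per-channel (biased) variance
y_ref = gamma[:, None, None, None] * (x - mean) / np.sqrt(var + 1e-4) + beta[:, None, None, None]
print(y_ref.shape)  # (32, 1, 1, 128)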

python/examples/einsum.py (new file, 197 lines)

@@ -0,0 +1,197 @@
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
#torch.backends.cudnn.benchmark = True
configs = []
# Matrix multiplication
MNK = [
    (512, 512, 512),
    (2048, 2048, 2048),
    (8192, 8192, 8192),
    # (64, 64, 64000),
    # (64, 64, 128000),
    # (256, 256, 64000),
    # (256, 256, 128000),
    # (1536, 16, 1536),
    # (1536, 32, 1536),
    # (1536, 64, 1536),
    # (1536, 128, 1536),
    # (4096, 16, 4096),
    # (4096, 32, 4096),
    # (4096, 64, 4096),
    # (4096, 128, 4096),
    # (127008, 768, 576)
]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a.t(), b)
    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b.t())
    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
# Relative attention
NTHSE = [
    (16, 512, 1, 64, 64),
    # (16, 512, 1, 128, 128),
    # (16, 512, 1, 256, 256),
    # (16, 512, 1, 256, 512),
    (16, 512, 8, 64, 64),
    # (16, 512, 8, 128, 128),
    # (16, 512, 8, 256, 256),
    # (16, 512, 8, 256, 512),
    # (64, 1024, 1, 64, 64),
    (64, 1024, 1, 128, 128),
    # (64, 1024, 1, 256, 256),
    # (64, 1024, 1, 256, 512),
    # (64, 1024, 8, 64, 64),
    (64, 1024, 8, 128, 128),
    # (64, 1024, 8, 256, 256),
    # (64, 1024, 8, 256, 512),
    # (128, 1024, 1, 64, 64),
    # (128, 1024, 1, 128, 128),
    # (128, 1024, 1, 256, 256),
    (128, 1024, 1, 256, 512),
    # (128, 1024, 8, 64, 64),
    # (128, 1024, 8, 128, 128),
    # (128, 1024, 8, 256, 256),
    (128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# 1D Dense convolution
NCHKR = [
    (1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
    configs += [([N, C, H],
                 [C, R, K],
                 [N, K, H - R + 1],
                 torch_fn,
                 'nc(h+r),crk->nkh',
                 dict())]
# 2D Dense convolution
NCHWKRS = [
    (8, 64, 128, 128, 768, 3, 3),
    (8, 128, 64, 64, 256, 3, 3),
    (8, 256, 32, 32, 512, 3, 3),
    (8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
    configs += [([N, C, H, W],
                 [C, R, S, K],
                 [N, K, H - R + 1, W - S + 1],
                 torch_fn,
                 'nc(h+r)(w+s),crsk->nkhw',
                 dict())]
# 3D Dense Convolution
NCDHWKTRS = [
    (8, 32, 27, 100, 100, 64, 3, 3, 3),
    (8, 64, 23, 48, 48, 256, 3, 3, 3),
    (8, 256, 19, 22, 22, 640, 3, 3, 3),
    (8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
    configs += [([N, C, D, H, W],
                 [C, T, R, S, K],
                 [N, K, D - T + 1, H - R + 1, W - S + 1],
                 torch_fn,
                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
                 dict())]
# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
    'shift_cuda', ['kernels/shift_cuda.cpp',
                   'kernels/shift_cuda_kernel.cu'],
    extra_cflags=['-O3'])

class shift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        ctx.save_for_backward(shift)
        return shift_cuda.forward(x, shift)

    @staticmethod
    def backward(ctx, grad_output):
        shift, = ctx.saved_tensors
        grad_output = shift_cuda.backward(grad_output, shift)
        return grad_output, None
NCHWKRS = [
    #(8, 64, 128, 128, 128, 3, 3),
    #(8, 128, 64, 64, 256, 3, 3),
    #(8, 256, 32, 32, 512, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
    def shift_conv(a, b, **kwargs):
        shift_h, shift_w = kwargs['sh'], kwargs['sw']
        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
        shift_torch = torch.from_numpy(shift_torch).cuda()
        a = shift.apply(a, shift_torch)
        b = b.permute(1, 0)
        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
        return torch.nn.functional.conv2d(a, b)
    configs += [([N, C, H, W],
                 [C, K],
                 [N, K, H, W],
                 shift_conv,
                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                 {'sh': shift_h, 'sw': shift_w})]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    dtype = torch.cuda.FloatTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # triton output
    tc = triton.ops.einsum(expr, a, b, c_shape, arrays=arrays, bench=True)
    # reference output
    if torch_fn:
        rc = torch_fn(a, b, **arrays)
    else:
        rc = torch.einsum(expr, a, b)
    # performance relative to equivalent matrix multiplication
    ctx = triton.ctx_registry[tc]
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    cmp_eqbmm = True
    if cmp_eqbmm:
        a = torch.rand(B, M, K).type(dtype).cuda()
        b = torch.rand(B, K, N).type(dtype).cuda()
        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench=True)
        ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
        cmp_str = f'({ratio:4.2f})'
    else:
        cmp_str = ''
    # test and benchmark
    bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
    diff = (tc - rc).abs().max() / rc.abs().max()
    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
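The extended subscripts such as 'nc(h+r),crk->nkh' express convolutions as einsums over shifted indices; the configs above already pair them with conv1d/conv2d/conv3d references. As a point of comparison, the 1D case can be written with plain torch.unfold and checked against conv1d; a self-contained sketch with small hypothetical sizes:

import torch

N, C, H, K, R = 2, 3, 10, 4, 3                      # hypothetical small sizes
a = torch.randn(N, C, H)
b = torch.randn(C, R, K)

windows = a.unfold(2, R, 1)                          # (N, C, H - R + 1, R): windows a[n, c, h + r]
ref = torch.einsum('nchr,crk->nkh', windows, b)      # plain-torch version of 'nc(h+r),crk->nkh'
check = torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
print(torch.allclose(ref, check, atol=1e-4))         # True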


@@ -0,0 +1,42 @@
#include <torch/torch.h>
#include <vector>
// CUDA forward declarations
at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift);

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift);

// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor shift_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  CHECK_INPUT(input);
  CHECK_INPUT(shift);
  return shift_cuda_forward(input, shift);
}

at::Tensor shift_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  CHECK_INPUT(grad_input);
  CHECK_INPUT(shift);
  return shift_cuda_backward(grad_input, shift);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &shift_forward, "Shift forward (CUDA)");
  m.def("backward", &shift_backward, "Shift backward (CUDA)");
}


@@ -0,0 +1,111 @@
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const int32_t* __restrict__ shift,
    scalar_t* __restrict__ output,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;
  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  const int32_t target_w = w + shift[2*c];
  const int32_t target_h = h + shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    output[target_idx] = input[idx];
  }
}
template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
    const scalar_t* __restrict__ grad_input,
    scalar_t* __restrict__ grad_output,
    const int32_t* __restrict__ shift,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;
  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  // the adjoint of a forward shift by (+shift_h, +shift_w) shifts by the opposite offsets
  const int32_t target_w = w - shift[2*c];
  const int32_t target_h = h - shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    grad_output[target_idx] = grad_input[idx];
  }
}
} // namespace
at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  const auto B = input.size(0);
  const auto C = input.size(1);
  const auto H = input.size(2);
  const auto W = input.size(3);
  const auto size = B*C*W*H;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto output = at::zeros_like(input);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
    shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        shift.data<int32_t>(),
        output.data<scalar_t>(),
        B,
        C,
        H,
        W);
  }));

  return output;
}

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  const auto B = grad_input.size(0);
  const auto C = grad_input.size(1);
  const auto H = grad_input.size(2);
  const auto W = grad_input.size(3);
  const auto size = B*C*W*H;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto grad_output = at::zeros_like(grad_input);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
    shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        grad_input.data<scalar_t>(),
        grad_output.data<scalar_t>(),
        shift.data<int32_t>(),
        B,
        C,
        H,
        W);
  }));

  return grad_output;
}
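For intuition, the forward kernel above scatters each channel's feature map by a per-channel offset (shift[2*c] along W, shift[2*c + 1] along H) and leaves uncovered positions at zero. A pure-PyTorch sketch of that behaviour, assuming a shift tensor of shape (C, 2) holding (w_offset, h_offset) per channel as built in einsum.py via np.column_stack (the helper name is illustrative):

import torch

def shift_forward_ref(x, shift):
    # x: (B, C, H, W); shift: (C, 2) integer tensor of (w_offset, h_offset) per channel
    B, C, H, W = x.shape
    out = torch.zeros_like(x)
    for c in range(C):
        sw, sh = int(shift[c, 0]), int(shift[c, 1])
        # source rows/cols whose shifted targets stay inside the image
        h0, h1 = max(0, -sh), min(H, H - sh)
        w0, w1 = max(0, -sw), min(W, W - sw)
        if h1 <= h0 or w1 <= w0:
            continue
        out[:, c, h0 + sh:h1 + sh, w0 + sw:w1 + sw] = x[:, c, h0:h1, w0:w1]
    return out

# e.g. shift_forward_ref(torch.randn(1, 2, 5, 5), torch.tensor([[1, 0], [0, -1]]))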