[GENERAL] Merged v1.0alpha into master. The added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error raised for kernels that spill registers
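To illustrate the new debugging capabilities: the Python binding now exposes the generated assembly of a compiled kernel (`kernel.asm`, see the changes to `triton/kernel.py` below) and honors a `TRITON_DEBUG_MODE` environment variable that copies tensor arguments into fresh CUDA buffers and synchronizes around launches. The snippet below is a minimal sketch; the kernel object, the device, and the `dump_assembly` helper are placeholders and not part of this commit.

```python
import os

# Opt into the new debug mode before launching any Triton kernel: tensor
# arguments are cloned into freshly allocated CUDA buffers and the device is
# synchronized around each launch, which helps localize invalid accesses.
os.environ['TRITON_DEBUG_MODE'] = '1'

def dump_assembly(kernel, device):
    # `kernel` is assumed to be a compiled `triton.kernel` and `device` the
    # torch device it targets; both are placeholders in this sketch.
    print(kernel.asm('ptx', device))   # generated PTX
    print(kernel.asm('sass', device))  # generated SASS
```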
@@ -1,56 +0,0 @@
import triton
import numpy as np
from enum import Enum

class MODE(Enum):
    TF = 1
    TORCH = 2

try:
    import tensorflow as tf
    mode = MODE.TF
except ModuleNotFoundError:
    pass

try:
    import torch
    mode = MODE.TORCH
except ModuleNotFoundError:
    pass


C, H, W, B = 32, 1, 1, 128

x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)

if mode == MODE.TORCH:
    fw_x = torch.from_numpy(x).cuda()
    fw_gamma = torch.from_numpy(gamma).cuda()
    fw_beta = torch.from_numpy(beta).cuda()
    fw_dy = torch.from_numpy(dy).cuda()
    # register gradients
    fw_x.requires_grad_(True)
    fw_gamma.requires_grad_(True)
    fw_beta.requires_grad_(True)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_y.backward(fw_dy)

if mode == MODE.TF:
    fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
    fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
    fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
    fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
    fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
    # run
    sess = tf.InteractiveSession()
    feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
    sess.run(tf.global_variables_initializer())
    result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
    print(result)
@@ -1,213 +0,0 @@
|
||||
import triton
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
import numpy as np
|
||||
#import utils
|
||||
from time import time
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
|
||||
configs = []
|
||||
|
||||
# Matrix multiplication
|
||||
MNK = [
|
||||
(512, 512 ,512),
|
||||
(2048, 2048, 2048),
|
||||
#(8192, 8192, 8192),
|
||||
|
||||
(64, 64, 64000),
|
||||
(64, 64, 128000),
|
||||
(256, 256, 64000),
|
||||
(256, 256, 128000),
|
||||
|
||||
(1536, 16, 1536),
|
||||
(1536, 32, 1536),
|
||||
(1536, 64, 1536),
|
||||
# (1536, 128, 1536),
|
||||
# (4096, 16, 4096),
|
||||
# (4096, 32, 4096),
|
||||
# (4096, 64, 4096),
|
||||
# (4096, 128, 4096),
|
||||
|
||||
# (127008, 768, 576)
|
||||
]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a, b)
|
||||
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
|
||||
#for M, N, K in MNK:
|
||||
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
|
||||
#for M, N, K in MNK:
|
||||
# matmul = lambda a, b: torch.matmul(a, b.t())
|
||||
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
|
||||
|
||||
# Relative attention
|
||||
NTHSE = [
|
||||
(16, 512, 1, 64, 64),
|
||||
# (16, 512, 1, 128, 128),
|
||||
# (16, 512, 1, 256, 256),
|
||||
# (16, 512, 1, 256, 512),
|
||||
(16, 512, 8, 64, 64),
|
||||
# (16, 512, 8, 128, 128),
|
||||
# (16, 512, 8, 256, 256),
|
||||
# (16, 512, 8, 256, 512),
|
||||
|
||||
# (64, 1024, 1, 64, 64),
|
||||
(64, 1024, 1, 128, 128),
|
||||
# (64, 1024, 1, 256, 256),
|
||||
# (64, 1024, 1, 256, 512),
|
||||
# (64, 1024, 8, 64, 64),
|
||||
(64, 1024, 8, 128, 128),
|
||||
# (64, 1024, 8, 256, 256),
|
||||
# (64, 1024, 8, 256, 512),
|
||||
|
||||
# (128, 1024, 1, 64, 64),
|
||||
# (128, 1024, 1, 128, 128),
|
||||
# (128, 1024, 1, 256, 256),
|
||||
(128, 1024, 1, 256, 512),
|
||||
# (128, 1024, 8, 64, 64),
|
||||
# (128, 1024, 8, 128, 128),
|
||||
# (128, 1024, 8, 256, 256),
|
||||
#(128, 1024, 8, 256, 512)
|
||||
]
|
||||
#for N, T, H, S, E in NTHSE:
|
||||
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
|
||||
#for N, T, H, S, E in NTHSE:
|
||||
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
|
||||
#for N, T, H, S, E in NTHSE:
|
||||
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
|
||||
|
||||
# 1D Dense convolution
|
||||
NCHKR = [
|
||||
#(1, 1152, 12602, 512, 3)
|
||||
]
|
||||
for N, C, H, K, R in NCHKR:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
||||
configs += [([N, C, H],
|
||||
[C, R, K],
|
||||
[N, K, H - R + 1],
|
||||
torch_fn,
|
||||
'nc(h+r),crk->nkh',
|
||||
dict(), None, None, None)]
|
||||
|
||||
# 2D Dense convolution
|
||||
NCHWKRS = [
|
||||
#(8, 64, 128, 128, 768, 3, 3),
|
||||
#(128, 3, 32, 32, 64, 3, 3),
|
||||
#(1, 1024, 32, 112, 112, 1024, 3, 3),
|
||||
#(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, G, H, W, K, R, S in NCHWKRS:
|
||||
stride = 2
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
|
||||
P = (H - R + 1) // stride
|
||||
Q = (W - S + 1) // stride
|
||||
transform_a = lambda a: a.view(N, G, C // G, H, W)
|
||||
transform_b = lambda b: b.view(C // G, R, S, G, K // G)
|
||||
transform_c = lambda c: c.view(N, K, P, Q)
|
||||
configs += [([N, C, H, W],
|
||||
[C // G, R, S, K],
|
||||
[N, G, K // G, P, Q],
|
||||
torch_fn,
|
||||
'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
|
||||
dict(), transform_a, transform_b, transform_c)]
|
||||
|
||||
|
||||
# 3D Dense Convolution
|
||||
NCDHWKTRS = [
|
||||
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
||||
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
||||
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
||||
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
||||
]
|
||||
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
|
||||
configs += [([N, C, D, H, W],
|
||||
[C, T, R, S, K],
|
||||
[N, K, D - T + 1, H - R + 1, W - R + 1],
|
||||
torch_fn,
|
||||
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
|
||||
dict(), None, None, None)]
|
||||
|
||||
|
||||
# Shift convolution
|
||||
shift_cuda = torch.utils.cpp_extension.load(
|
||||
'shift_cuda', ['kernels/shift_cuda.cpp',
|
||||
'kernels/shift_cuda_kernel.cu'],
|
||||
extra_cflags=['-O3'])
|
||||
class shift(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shift):
|
||||
ctx.save_for_backward(shift)
|
||||
return shift_cuda.forward(x, shift)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
shift, = ctx.saved_tensors
|
||||
grad_output = shift_cuda.backward(grad_output, shift)
|
||||
|
||||
return grad_output, None
|
||||
|
||||
NCHWKRS = [
|
||||
#(8, 64, 128, 128, 128, 3, 3),
|
||||
#(8, 128, 64, 64, 256, 3, 3),
|
||||
#(8, 256, 32, 32, 512, 3, 3),
|
||||
#(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, H, W, K, R, S in NCHWKRS:
|
||||
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
|
||||
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
|
||||
def shift_conv(a, b, **kwargs):
|
||||
shift_h, shift_w = kwargs['sh'], kwargs['sw']
|
||||
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
|
||||
shift_torch = torch.from_numpy(shift_torch).cuda()
|
||||
a = shift.apply(a, shift_torch)
|
||||
b = b.permute(1, 0)
|
||||
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
|
||||
return torch.nn.functional.conv2d(a, b)
|
||||
configs += [([N, C, H, W],
|
||||
[C, K],
|
||||
[N, K, H, W],
|
||||
shift_conv,
|
||||
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
||||
{'sh': shift_h, 'sw': shift_w},
|
||||
None, None, None)]
|
||||
|
||||
# Benchmark
|
||||
torch.set_num_threads(1)
|
||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
|
||||
transform_a, transform_b, transform_c in configs:
|
||||
dtype = torch.cuda.FloatTensor
|
||||
# initialize input tensors
|
||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||
# reference output
|
||||
if torch_fn:
|
||||
rc = torch_fn(a, b, **arrays)
|
||||
else:
|
||||
rc = torch.einsum(expr, a, b)
|
||||
# triton output
|
||||
ta = a if transform_a is None else transform_a(a)
|
||||
tb = b if transform_b is None else transform_b(b)
|
||||
tc = torch.empty(c_shape, device=a.device)
|
||||
triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
|
||||
ctx = triton.ops._einsum.registry[tc]
|
||||
tc = tc if transform_c is None else transform_c(tc)
|
||||
# performance relative to equivalent matrix multiplication
|
||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||
cmp_eqbmm = True
|
||||
if cmp_eqbmm:
|
||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
c = torch.empty((B, M, N), device=a.device).cuda()
|
||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
|
||||
ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
|
||||
cmp_str = f'({ratio:4.2f})'
|
||||
else:
|
||||
cmp_str = ''
|
||||
# test and benchmark
|
||||
bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
|
||||
diff = (tc - rc).abs().max() / rc.abs().max()
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
|
@@ -1,42 +0,0 @@
#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift);

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor shift_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  CHECK_INPUT(input);
  CHECK_INPUT(shift);

  return shift_cuda_forward(input, shift);
}

at::Tensor shift_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  CHECK_INPUT(grad_input);
  CHECK_INPUT(shift);
  return shift_cuda_backward(grad_input, shift);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &shift_forward, "Shift forward (CUDA)");
  m.def("backward", &shift_backward, "Shift backward (CUDA)");
}
@@ -1,111 +0,0 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_forward_kernel(
|
||||
const scalar_t* __restrict__ input,
|
||||
const int32_t* __restrict__ shift,
|
||||
scalar_t* __restrict__ output,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t H,
|
||||
const int32_t W) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*H*W;
|
||||
|
||||
const int32_t CHW = C*H*W;
|
||||
const int32_t HW = H*W;
|
||||
const int32_t b = idx / CHW;
|
||||
const int32_t c = (idx - b*CHW) / HW;
|
||||
const int32_t h = (idx - b*CHW - c*HW) / W;
|
||||
const int32_t w = idx - b*CHW - c*HW - h*W;
|
||||
const int32_t target_w = w + shift[2*c];
|
||||
const int32_t target_h = h + shift[2*c + 1];
|
||||
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
output[target_idx] = input[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_backward_kernel(
|
||||
const scalar_t* __restrict__ grad_input,
|
||||
scalar_t* __restrict__ grad_output,
|
||||
const int32_t* __restrict__ shift,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t W,
|
||||
const int32_t H) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*W*H;
|
||||
const int32_t CWH = C*W*H;
|
||||
const int32_t WH = W*H;
|
||||
const int32_t b = idx / CWH;
|
||||
const int32_t c = (idx - b*CWH) / WH;
|
||||
const int32_t w = (idx - b*CWH - c*WH) / W;
|
||||
const int32_t h = idx - b*CWH - c*WH - w*H;
|
||||
const int32_t target_w = w - shift[2*c];
|
||||
const int32_t target_h = h - shift[2*c + 1];
|
||||
const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
grad_output[target_idx] = grad_input[idx];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
at::Tensor shift_cuda_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto H = input.size(2);
|
||||
const auto W = input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
|
||||
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
input.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
output.data<scalar_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
at::Tensor shift_cuda_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = grad_input.size(0);
|
||||
const auto C = grad_input.size(1);
|
||||
const auto H = grad_input.size(2);
|
||||
const auto W = grad_input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto grad_output = at::zeros_like(grad_input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
|
||||
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
grad_input.data<scalar_t>(),
|
||||
grad_output.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return grad_output;
|
||||
}
|
@@ -1,109 +0,0 @@
|
||||
import triton
|
||||
import numpy
|
||||
import torch
|
||||
import itertools
|
||||
|
||||
torch.manual_seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
||||
def to_sparse(expr, data, layout, shape, block):
|
||||
# shape of result
|
||||
sparse = None
|
||||
shape_ret = []
|
||||
for i, d in enumerate(expr):
|
||||
if d.isupper() and sparse is None:
|
||||
sparse = i
|
||||
shape_ret.append(int(layout.sum()))
|
||||
if d.isupper():
|
||||
shape_ret.append(block[d])
|
||||
else:
|
||||
shape_ret.append(shape[i])
|
||||
# iterator
|
||||
steps = [block[d] if d.isupper() else 1 for d in expr]
|
||||
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
|
||||
# create result
|
||||
ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
|
||||
blockid = 0
|
||||
nzblockid = 0
|
||||
for curr in itertools.product(*it):
|
||||
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
|
||||
blockid = 0
|
||||
nzblockid = 0
|
||||
data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
|
||||
ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
|
||||
ret_slice.insert(sparse, nzblockid)
|
||||
if int(layout.view(-1)[blockid]):
|
||||
ret[ret_slice] = data[data_slice]
|
||||
nzblockid += 1
|
||||
blockid += 1
|
||||
return ret
|
||||
|
||||
def to_dense(expr, data, layout, shape, block):
|
||||
sparse = None
|
||||
for i, d in enumerate(expr):
|
||||
if d.isupper() and sparse is None:
|
||||
sparse = i
|
||||
|
||||
ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
|
||||
steps = [block[d] if d.isupper() else 1 for d in expr]
|
||||
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
|
||||
blockid = 0
|
||||
nzblockid = 0
|
||||
for curr in itertools.product(*it):
|
||||
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
|
||||
blockid = 0
|
||||
nzblockid = 0
|
||||
ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
|
||||
data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
|
||||
data_slice.insert(sparse, nzblockid)
|
||||
if int(layout.view(-1)[blockid]):
|
||||
ret[ret_slice] = data[data_slice]
|
||||
nzblockid += 1
|
||||
blockid += 1
|
||||
return ret
|
||||
|
||||
def test_expr(expr, shape, blocks):
|
||||
# decompose expr
|
||||
expr_a, expr_bc = expr.split(",")
|
||||
expr_b, expr_c = expr_bc.split("->")
|
||||
# check with argument is sparse
|
||||
sparse_a = any(x.isupper() for x in expr_a)
|
||||
sparse_b = any(x.isupper() for x in expr_b)
|
||||
sparse_c = any(x.isupper() for x in expr_c)
|
||||
# allocate data
|
||||
shape_a = [shape[d.lower()] for d in expr_a]
|
||||
shape_b = [shape[d.lower()] for d in expr_b]
|
||||
shape_c = [shape[d.lower()] for d in expr_c]
|
||||
ref_a = torch.rand(*shape_a, device='cuda')
|
||||
ref_b = torch.rand(*shape_b, device='cuda')
|
||||
ref_c = torch.zeros(*shape_c, device='cuda')
|
||||
# layouts
|
||||
layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
|
||||
layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
|
||||
layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
|
||||
layout_a = torch.randint(0, 2, layout_a, device='cuda')
|
||||
layout_b = torch.randint(0, 2, layout_b, device='cuda')
|
||||
layout_c = torch.randint(0, 2, layout_c, device='cuda')
|
||||
# triton computation
|
||||
triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
|
||||
triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
|
||||
layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
|
||||
triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
|
||||
torch.cuda.synchronize()
|
||||
# reference computation
|
||||
ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
|
||||
ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
|
||||
ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
|
||||
if sparse_c:
|
||||
ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
|
||||
torch.cuda.synchronize()
|
||||
print((ref_c - triton_c).abs().max())
|
||||
|
||||
|
||||
|
||||
|
||||
# shape characteristics
|
||||
test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
|
||||
test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
|
||||
test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})
|
||||
|
@@ -171,7 +171,7 @@ class _conv(torch.autograd.Function):
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = triton.empty([Z, CO, P, Q], dtype=dtype)
c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
                    triton.cdiv(CO, opt.d('TN'))]
@@ -3,6 +3,9 @@ import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
#define STM 4
|
||||
#define STN 4
|
||||
|
||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
@@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int ridz = get_program_id(2);
|
||||
int gridx = M / TM;
|
||||
int gridy = N / TN;
|
||||
int rid = ridx + ridy * gridx;
|
||||
ridx = rid / gridy;
|
||||
ridy = rid % gridy;
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = M / TM;
|
||||
int gridn = N / TN;
|
||||
int stgridm = (gridm + STM - 1) / STM;
|
||||
int stgridn = (gridn + STN - 1) / STN;
|
||||
int stid = pid / (STM * STN);
|
||||
int laneid = pid % (STM * STN);
|
||||
int stm = stid / stgridn;
|
||||
int stn = stid % stgridn;
|
||||
int lanem = laneid / STN;
|
||||
int lanen = laneid % STN;
|
||||
int pidm = stm*STM + lanem;
|
||||
int pidn = stn*STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = ridz * K + 0 ... TK;
|
||||
int rk[TK] = pidz * K + 0 ... TK;
|
||||
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
@@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
acc += a @ b;
|
||||
bool checka[TM, TK] = k > TK;
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
acc += a @ b;
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
@@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
int rxm[TM] = ridx * TM + 0 ... TM;
|
||||
int rxn[TN] = ridy * TN + 0 ... TN;
|
||||
int rxm[TM] = pidm * TM + 0 ... TM;
|
||||
int rxn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
|
||||
@@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
*?(checkc) pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + rid;
|
||||
int *plock = locks + pid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
@@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||
'TM' : [128],
|
||||
'TN' : [128],
|
||||
'TK' : [16],
|
||||
'TK' : [32],
|
||||
'TZ' : [1]
|
||||
}
|
||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||
@@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
c = torch.empty([M,N], dtype=dtype, device=a.device)
|
||||
print(kernel.asm('sass', c.device))
|
||||
print(kernel.asm('ptx', c.device))
|
||||
# enqueue
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
|
||||
triton.cdiv(N, opt.d('TN'))]
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
|
||||
time = kernel(a, b, c, 1., M, N, K,
|
||||
a.stride(0), b.stride(0), c.stride(0), grid=grid)
|
||||
return c
|
||||
@@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half()
|
||||
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
|
||||
|
||||
print(torch.allclose(zc, zc_))
|
@@ -111,7 +111,7 @@ setup(
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops'],
packages=['triton', 'triton/_C'],
install_requires=['numpy', 'torch', 'sympy'],
package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {
|
||||
|
||||
void register_fn(const map_key_t& key,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
const rt::options_space_t& opt) {
|
||||
if(id_fn_map.find(key) == id_fn_map.end())
|
||||
id_fn_map[key].reset(new rt::function(src, opt, ""));
|
||||
}
|
||||
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
|
||||
id_fn_map.erase(key);
|
||||
}
|
||||
|
||||
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
|
||||
triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
|
||||
return id_fn_map[key]->ptx(&device, opt);
|
||||
std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
|
||||
triton::driver::cu_device device(key.second, false);
|
||||
return id_fn_map[key]->get_asm(mode, &device, opt);
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
@@ -63,7 +63,7 @@ size_t make_op_id() {
|
||||
|
||||
/* Function signature */
|
||||
void make_module(const std::string& src, ir::module* ir,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
const runtime::options_space_t& opt) {
|
||||
std::string copy = triton::runtime::function::preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
|
||||
}
|
||||
|
||||
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
const runtime::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
||||
return ret;
|
||||
}
|
||||
|
||||
typedef triton::runtime::function::options_t options_t;
|
||||
typedef triton::runtime::function::options_space_t options_space_t;
|
||||
typedef triton::runtime::options_t options_t;
|
||||
typedef triton::runtime::options_space_t options_space_t;
|
||||
|
||||
PYBIND11_MODULE(libtriton, m) {
|
||||
m.doc() = "Python bindings to the C++ Triton API";
|
||||
@@ -112,6 +112,10 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
.value("float", rt::FLOAT_T)
|
||||
.value("double", rt::DOUBLE_T)
|
||||
.value("buffer", rt::BUFFER_T);
|
||||
|
||||
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
|
||||
.value("ptx", rt::ASM_NV_PTX)
|
||||
.value("sass", rt::ASM_NV_SASS);
|
||||
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
.def(pybind11::init<>())
|
||||
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("get_fn_signature", &get_fn_signature);
|
||||
m.def("get_fn_ptx", &get_fn_ptx);
|
||||
m.def("get_fn_asm", &get_fn_asm);
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("delete_grid", &delete_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
|
@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor raw_like(torch::Tensor x){
|
||||
torch::Tensor cuda_empty_like(torch::Tensor x){
|
||||
if(x.nbytes() == 0)
|
||||
return torch::empty_like(x);
|
||||
C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
|
||||
auto shape = x.sizes();
|
||||
CUdeviceptr data;
|
||||
triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
|
||||
auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
|
||||
auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
|
||||
ret.copy_(x);
|
||||
void* data;
|
||||
cudaMalloc(&data, x.nbytes());
|
||||
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
|
||||
static auto registry = torch::RegisterOperators()
|
||||
.op("triton::launch_kernel", &launch_kernel)
|
||||
.op("triton::raw_like", &raw_like)
|
||||
.op("triton::cuda_empty_like", &cuda_empty_like)
|
||||
.op("triton::cdiv_sum", &cdiv_sum)
|
||||
.op("triton::synchronize", &synchronize);
|
||||
|
@@ -1,7 +1,4 @@
from .kernel import *
import triton.ops
#import triton.nn


# clean-up libtriton resources
import atexit
@@ -68,8 +68,17 @@ class kernel:
|
||||
size = sum([sizes[x] for x in arg_types])
|
||||
self.tys = ''.join([codes[x] for x in arg_types])
|
||||
|
||||
def ptx(self, device, **kwargs):
|
||||
def asm(self, mode, device, **kwargs):
|
||||
dev_id = device.index
|
||||
# assembly mode
|
||||
supported = {
|
||||
'ptx': libtriton.asm_mode.ptx,
|
||||
'sass': libtriton.asm_mode.sass,
|
||||
}
|
||||
if mode not in supported:
|
||||
raise('ASM mode must be in ', supported.keys())
|
||||
mode = supported[mode]
|
||||
# disambiguates #defines
|
||||
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
|
||||
def _single_value_or_err(x, key):
|
||||
if isinstance(x, list) and len(x) == 1:
|
||||
@@ -86,15 +95,18 @@ class kernel:
|
||||
opt = libtriton.options()
|
||||
opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
|
||||
opt.defines = defines
|
||||
return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
|
||||
# run
|
||||
return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if 'TRITON_DEBUG_MODE' in os.environ:
|
||||
_args = args
|
||||
args = [x for x in args]
|
||||
args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
|
||||
for i in range(len(args)):
|
||||
if isinstance(args[i], torch.Tensor):
|
||||
args[i] = torch.ops.triton.raw_like(args[i])
|
||||
args[i] = torch.ops.triton.cuda_empty_like(args[i])
|
||||
args[i].copy_(_args[i])
|
||||
torch.cuda.synchronize()
|
||||
for x in args:
|
||||
if isinstance(x, torch.Tensor):
|
||||
device = x.device.index
|
||||
@@ -116,6 +128,8 @@ class kernel:
|
||||
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
|
||||
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
|
||||
if 'TRITON_DEBUG_MODE' in os.environ:
|
||||
torch.cuda.synchronize()
|
||||
for i in range(len(args)):
|
||||
if isinstance(args[i], torch.Tensor):
|
||||
_args[i].copy_(args[i])
|
||||
_args[i].copy_(args[i].clone())
|
||||
args = _args
|
@@ -1,2 +0,0 @@
from .conv import replace_conv2d
from .attention import replace_mah
@@ -1,312 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
|
||||
def bmm(x, w, mask = None):
|
||||
b, m, k = x.size()
|
||||
b, k, n = w.size()
|
||||
out = torch.empty([b, m, n], device=x.device)
|
||||
triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False)
|
||||
return out
|
||||
|
||||
def multi_head_attention_forward(query, # type: Tensor
|
||||
key, # type: Tensor
|
||||
value, # type: Tensor
|
||||
embed_dim_to_check, # type: int
|
||||
num_heads, # type: int
|
||||
in_proj_weight, # type: Tensor
|
||||
in_proj_bias, # type: Tensor
|
||||
bias_k, # type: Optional[Tensor]
|
||||
bias_v, # type: Optional[Tensor]
|
||||
add_zero_attn, # type: bool
|
||||
dropout_p, # type: float
|
||||
out_proj_weight, # type: Tensor
|
||||
out_proj_bias, # type: Tensor
|
||||
training=True, # type: bool
|
||||
key_padding_mask=None, # type: Optional[Tensor]
|
||||
need_weights=True, # type: bool
|
||||
attn_mask=None, # type: Optional[Tensor]
|
||||
use_separate_proj_weight=False, # type: bool
|
||||
q_proj_weight=None, # type: Optional[Tensor]
|
||||
k_proj_weight=None, # type: Optional[Tensor]
|
||||
v_proj_weight=None, # type: Optional[Tensor]
|
||||
static_k=None, # type: Optional[Tensor]
|
||||
static_v=None, # type: Optional[Tensor]
|
||||
acc_bitmask=None
|
||||
):
|
||||
# type: (...) -> Tuple[Tensor, Optional[Tensor]]
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
||||
add_zero_attn: add a new batch of zeros to the key and
|
||||
value sequences at dim=1.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: mask that prevents attention to certain positions. This is an additive mask
|
||||
(i.e. the values will be added to the attention layer).
|
||||
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
||||
and value in differnt forms. If false, in_proj_weight will be used, which is
|
||||
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
||||
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
static_k, static_v: static key and value used for attention operators.
|
||||
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
|
||||
- attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
||||
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
||||
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
||||
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
||||
|
||||
Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size() == value.size()
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
|
||||
if not use_separate_proj_weight:
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = F.linear(query, _w, _b)
|
||||
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = None
|
||||
v = None
|
||||
else:
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = F.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = F.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = F.linear(value, _w, _b)
|
||||
else:
|
||||
q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
|
||||
len1, len2 = q_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == query.size(-1)
|
||||
|
||||
k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
|
||||
len1, len2 = k_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == key.size(-1)
|
||||
|
||||
v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
|
||||
len1, len2 = v_proj_weight_non_opt.size()
|
||||
assert len1 == embed_dim and len2 == value.size(-1)
|
||||
|
||||
if in_proj_bias is not None:
|
||||
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
|
||||
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
|
||||
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
|
||||
else:
|
||||
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
|
||||
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
|
||||
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
|
||||
q = q * scaling
|
||||
|
||||
if bias_k is not None and bias_v is not None:
|
||||
if static_k is None and static_v is None:
|
||||
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask,
|
||||
torch.zeros((attn_mask.size(0), 1),
|
||||
dtype=attn_mask.dtype,
|
||||
device=attn_mask.device)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
|
||||
dtype=key_padding_mask.dtype,
|
||||
device=key_padding_mask.device)], dim=1)
|
||||
else:
|
||||
assert static_k is None, "bias cannot be added to static key."
|
||||
assert static_v is None, "bias cannot be added to static value."
|
||||
else:
|
||||
assert bias_k is None
|
||||
assert bias_v is None
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
if k is not None:
|
||||
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
if v is not None:
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
if static_k is not None:
|
||||
assert static_k.size(0) == bsz * num_heads
|
||||
assert static_k.size(2) == head_dim
|
||||
k = static_k
|
||||
|
||||
if static_v is not None:
|
||||
assert static_v.size(0) == bsz * num_heads
|
||||
assert static_v.size(2) == head_dim
|
||||
v = static_v
|
||||
|
||||
src_len = k.size(1)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if add_zero_attn:
|
||||
src_len += 1
|
||||
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
|
||||
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
|
||||
dtype=attn_mask.dtype,
|
||||
device=attn_mask.device)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
|
||||
dtype=key_padding_mask.dtype,
|
||||
device=key_padding_mask.device)], dim=1)
|
||||
|
||||
|
||||
attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask)
|
||||
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float('-inf'),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
|
||||
|
||||
attn_output_weights = F.softmax(
|
||||
attn_output_weights, dim=-1)
|
||||
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
|
||||
|
||||
attn_output = bmm(attn_output_weights, v, mask=acc_bitmask)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
||||
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
class MultiheadAttention(nn.modules.activation.MultiheadAttention):
|
||||
|
||||
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None):
|
||||
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim)
|
||||
self.acc_bitmask = acc_bitmask
|
||||
|
||||
|
||||
def forward(self, query, key, value, key_padding_mask=None,
|
||||
need_weights=True, attn_mask=None):
|
||||
if not self._qkv_same_embed_dim:
|
||||
return multi_head_attention_forward(
|
||||
query, key, value, self.embed_dim, self.num_heads,
|
||||
self.in_proj_weight, self.in_proj_bias,
|
||||
self.bias_k, self.bias_v, self.add_zero_attn,
|
||||
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
||||
attn_mask=attn_mask, use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
acc_bitmask=self.acc_bitmask)
|
||||
else:
|
||||
return multi_head_attention_forward(
|
||||
query, key, value, self.embed_dim, self.num_heads,
|
||||
self.in_proj_weight, self.in_proj_bias,
|
||||
self.bias_k, self.bias_v, self.add_zero_attn,
|
||||
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
acc_bitmask=self.acc_bitmask)
|
||||
|
||||
|
||||
def replace_mah(model, mask = None):
|
||||
for child_name, child in model.named_children():
|
||||
if isinstance(child, nn.modules.activation.MultiheadAttention):
|
||||
add_bias_kv = child.bias_k is not None
|
||||
device = child.in_proj_weight.device
|
||||
mah = MultiheadAttention(child.embed_dim, child.num_heads,
|
||||
dropout=child.dropout, add_bias_kv=add_bias_kv,
|
||||
add_zero_attn=child.add_zero_attn, kdim=child.kdim,
|
||||
vdim=child.vdim, acc_bitmask=mask).to(device)
|
||||
for yparam, xparam in zip(mah.parameters(), child.parameters()):
|
||||
yparam.data.copy_(xparam.data)
|
||||
setattr(model, child_name, mah)
|
||||
else:
|
||||
replace_mah(child, mask)
|
@@ -1,166 +0,0 @@
|
||||
import triton
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class _conv2d(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias,
|
||||
stride, padding, dilation, groups,
|
||||
acc_bitmask):
|
||||
assert dilation == (1, 1)
|
||||
assert groups == 1
|
||||
assert bias == None
|
||||
pad_h, pad_w = padding
|
||||
stride_h, stride_w = stride
|
||||
n, c, h, w = x.size()
|
||||
k, c, r, s = weight.size()
|
||||
# allocate output
|
||||
p = (h + 2*padding[0] - r)//stride[0] + 1
|
||||
q = (w + 2*padding[1] - s)//stride[1] + 1
|
||||
output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device)
|
||||
# padding
|
||||
if pad_h or pad_w:
|
||||
x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h])
|
||||
# convolution
|
||||
triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
|
||||
x, weight, mask=acc_bitmask,
|
||||
output=output,
|
||||
values = {'pad_h': pad_h,
|
||||
'stride_h': stride_h,
|
||||
'pad_w': pad_w,
|
||||
'stride_w': stride_w})
|
||||
# prepare backprop
|
||||
ctx.save_for_backward(x, weight)
|
||||
ctx.stride = stride
|
||||
ctx.padding = padding
|
||||
ctx.acc_bitmask = acc_bitmask
|
||||
# return
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
# retrieve contextual information
|
||||
x, weight = ctx.saved_tensors
|
||||
stride = ctx.stride
|
||||
padding = ctx.padding
|
||||
acc_bitmask = ctx.acc_bitmask
|
||||
# gradient of the input
|
||||
dx = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
# dy must be padded
|
||||
n, k, p, q = dy.size()
|
||||
n, c, h, w = x.size()
|
||||
k, c, r, s = weight.size()
|
||||
dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
|
||||
# have to be careful here
|
||||
# the gradient of strided conv is a conv over a sparse image
|
||||
# which can be decomposed as a set of smaller convs
|
||||
dx = torch.empty_like(x)
|
||||
for offh in range(stride[0]):
|
||||
for offw in range(stride[1]):
|
||||
poffh = (offh + padding[0]) % stride[0]
|
||||
poffw = (offw + padding[1]) % stride[1]
|
||||
pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
|
||||
pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
|
||||
if poffh >= r or poffw >= s:
|
||||
dx[:, :, offh::stride[0], offw::stride[1]] = 0
|
||||
else:
|
||||
triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
|
||||
dypad[:, :, :, :],
|
||||
weight[:, :, poffh::stride[0], poffw::stride[1]],
|
||||
output = dx[:, :, offh::stride[0], offw::stride[1]],
|
||||
mask = acc_bitmask,
|
||||
values = {'pad_h': pad_h,
|
||||
'pad_w': pad_w})
|
||||
|
||||
# gradient for the weight
|
||||
dw = None
|
||||
if ctx.needs_input_grad[1]:
|
||||
dw = torch.empty_like(weight)
|
||||
triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
|
||||
x, dy, output = dw, mask = acc_bitmask)
|
||||
#print('dw: ', dw.view(-1)[0])
|
||||
return dx, dw, None, None, None, None, None, None
|
||||
conv2d = _conv2d.apply
|
||||
|
||||
class Conv2d(nn.Conv2d):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1,
|
||||
bias=True, padding_mode='zeros',
|
||||
acc_bitmask = None):
|
||||
super(Conv2d, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||
groups, bias, padding_mode)
|
||||
self.acc_bitmask = acc_bitmask
|
||||
|
||||
def forward(self, input):
|
||||
#if self.kernel_size[0] == 3 and self.stride[0] != 1:
|
||||
#print(self.padding, self.stride, input.size(), self.weight.size())
|
||||
# return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return conv2d(input, self.weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups,
|
||||
self.acc_bitmask)
|
||||
|
||||
|
||||
def replace_conv2d(model, acc_bitmask = None):
|
||||
for child_name, child in model.named_children():
|
||||
if isinstance(child, nn.Conv2d):
|
||||
conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size,
|
||||
child.stride, child.padding, child.dilation, child.groups,
|
||||
child.bias, child.padding_mode,
|
||||
acc_bitmask=acc_bitmask)
|
||||
for yparam, xparam in zip(conv2d.parameters(), child.parameters()):
|
||||
yparam.data.copy_(xparam.data)
|
||||
setattr(model, child_name, conv2d)
|
||||
else:
|
||||
replace_conv2d(child, acc_bitmask)
|
||||
|
||||
# initialize input
|
||||
#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
|
||||
#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
|
||||
#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
|
||||
#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
|
||||
|
||||
if __name__ == '__main__':
|
||||
N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
|
||||
#N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
|
||||
#N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
|
||||
pad, stride = 0, 1
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn((N, C, H, W)).cuda()
|
||||
x.requires_grad_(True)
|
||||
#x.data[:] = 1
|
||||
# initialize layers
|
||||
torch.manual_seed(0)
|
||||
rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
|
||||
torch.manual_seed(0)
|
||||
tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
|
||||
#rconv2d.weight.data[:] = 1
|
||||
#tconv2d.weight.data[:] = 1
|
||||
ry = rconv2d(x)
|
||||
ty = tconv2d(x)
|
||||
# reference
|
||||
dy = torch.randn(ry.size()).cuda()
|
||||
#dy.data[:] = 1
|
||||
ry.backward(dy)
|
||||
rdx = x.grad.clone()
|
||||
rdw = rconv2d.weight.grad.clone()
|
||||
x.grad.zero_()
|
||||
# triton
|
||||
ty.backward(dy)
|
||||
tdx = x.grad.clone()
|
||||
tdw = tconv2d.weight.grad.clone()
|
||||
x.grad.zero_()
|
||||
# print error
|
||||
diff = lambda x, y: (x - y).abs().max()
|
||||
print(diff(ry, ty))
|
||||
print(diff(rdx, tdx))
|
||||
print(diff(rdw, tdw))
|
||||
#print((rdx - tdx).abs())
|
||||
|
||||
#print((rdx[0,0,:,:] - tdx[0,0,:,:]))
|
||||
#print(rdx[0,0,:,:])
|
||||
#print(tdx[0,0,:,:])
|
@@ -1,13 +0,0 @@
import torch
import triton


def linear(x, w, bias = None):
    print(x.size(), w.size())
    m, k = x.size()
    k, n = w.size()
    out = torch.empty([m, n], device=x.device)
    triton.ops.einsum('mk,nk->mn', x, w, bias)
    if bias is not None:
        out += bias
    return out
@@ -1,2 +0,0 @@
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm
@@ -1,136 +0,0 @@
|
||||
import triton
|
||||
import torch
|
||||
import math
|
||||
|
||||
class _batchnorm(torch.autograd.Function):
|
||||
|
||||
fwd_src = """
|
||||
void fwdbatchnorm(float *Y, float *M, float *V,
|
||||
float *X, float *G, float *B,
|
||||
int N, float eps) {
|
||||
// pointers
|
||||
int c = get_program_id(1);
|
||||
int rm[TM] = 0 ... TM;
|
||||
float *px[TM] = X + rm + c*N;
|
||||
float* py[TM] = Y + rm + c*N;
|
||||
|
||||
// compute mean
|
||||
float accm[TM] = 0;
|
||||
for(int i = 0; i < N; i = i + TM)
|
||||
accm = accm + *(px + i);
|
||||
float mean = (float)accm[+] / N;
|
||||
*(M + c) = mean;
|
||||
|
||||
// compute variance
|
||||
float accv[TM] = 0;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
float x[TM] = *(px + i);
|
||||
x = x - mean;
|
||||
accv = accv + x*x;
|
||||
}
|
||||
float var = (float)accv[+] / N;
|
||||
*(V + c) = var;
|
||||
|
||||
// Normalize batch
|
||||
float gamma = *(G + c);
|
||||
float beta = *(B + c);
|
||||
float rstdg = 1 / sqrtf(var + eps) * gamma;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
float x[TM] = *(px + i);
|
||||
float y[TM] = (x - mean)*rstdg + beta;
|
||||
*(py + i) = y;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
bwd_src = """
|
||||
void bwdbatchnorm(float *DX, float *DG, float *DB,
|
||||
float *DY, float *X, float *G,
|
||||
float *M, float *V,
|
||||
int N, float epsilon) {
|
||||
|
||||
// pointers
|
||||
int c = get_program_id(1);
|
||||
int rx[TM] = 0 ... TM;
|
||||
int offset = c*N;
|
||||
float* px[TM] = X + rx + offset;
|
||||
float* pdy[TM] = DY + rx + offset;
|
||||
float* pdx[TM] = DX + rx + offset;
|
||||
|
||||
// fetch statistics
|
||||
float gamma = *(G + c);
|
||||
float mean = *(M + c);
|
||||
float var = *(V + c);
|
||||
float rstd = 1 / sqrtf(var + epsilon);
|
||||
|
||||
// compute dgamma and dbeta
|
||||
float acc_dg[TM] = 0;
|
||||
float acc_db[TM] = 0;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
float x[TM] = *(px + i);
|
||||
float dy[TM] = *(pdy + i);
|
||||
acc_dg += dy*(x - mean)*rstd;
|
||||
acc_db += dy;
|
||||
}
|
||||
float dg = acc_dg[+];
|
||||
float db = acc_db[+];
|
||||
*(DG + c) = dg;
|
||||
*(DB + c) = db;
|
||||
|
||||
// compute dx
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
float x[TM] = *(px + i);
|
||||
float dy[TM] = *(pdy + i);
|
||||
float xhat[TM] = (x - mean) * rstd;
|
||||
float xtmp[TM] = (xhat * dg + db) / N;
|
||||
float dx[TM] = (dy - xtmp) * rstd * gamma;
|
||||
*(pdx + i) = dx;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
fwd_kernel = None
|
||||
bwd_kernel = None
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, gamma, beta, eps):
|
||||
# lazy compilation of kernel
|
||||
if _batchnorm.fwd_kernel is None:
|
||||
_batchnorm.fwd_kernel = triton.kernel(fwd_src, defines = {'TM': 128})
|
||||
# shapes
|
||||
shape = triton.shape(x)
|
||||
dtype = x.dtype
|
||||
# allocate outputs
|
||||
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
|
||||
y = triton.empty(shape, dtype=dtype)
|
||||
mean = triton.empty([C], dtype=dtype)
|
||||
var = triton.empty([C], dtype=dtype)
|
||||
# execute kernels
|
||||
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
|
||||
grid = lambda opt: [1, C])
|
||||
# save
|
||||
ctx.save_for_backward(x, gamma, beta, mean, var)
|
||||
ctx.eps = eps
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
# lazy compilation of kernel
|
||||
if _batchnorm.bwd_kernel is None:
|
||||
_batchnorm.bwd_kernel = triton.kernel(bwd_src, defines = {'TN': 128})
|
||||
# retrieve info
|
||||
x, gamma, beta, mean, var = ctx.saved_tensors
|
||||
eps = ctx.eps
|
||||
# allocate result
|
||||
dx = triton.empty(triton.shape(x), dtype=x.dtype)
|
||||
dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
|
||||
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
|
||||
# execute
|
||||
C, H, W, B = triton.shape(x)
|
||||
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
|
||||
x, gamma, mean, var,
|
||||
H*W*B, eps,
|
||||
grid = lambda opt: [1, C])
|
||||
return dx, dgamma, dbeta, None
|
||||
|
||||
batchnorm = _batchnorm.apply
|
@@ -1,794 +0,0 @@
|
||||
from math import ceil, log2
|
||||
from enum import IntEnum
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
from collections import OrderedDict
|
||||
from collections import namedtuple
|
||||
import re
|
||||
import string
|
||||
import triton
|
||||
import torch
|
||||
# numpy -- ideally removed in a future release
|
||||
import numpy as np
|
||||
# sympy -- ideally removed in a future release
|
||||
import sympy as sp
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from sympy.printing.ccode import C89CodePrinter
|
||||
|
||||
|
||||
class _einsum(torch.autograd.Function):
|
||||
|
||||
|
||||
#############################
|
||||
## Triton-C code generation
|
||||
#############################
|
||||
def print_cc(expr, axes_0, axes_1, axes_2, prefix):
|
||||
if expr in axes_0:
|
||||
return f'{prefix}r{expr}[:, newaxis, newaxis]'
|
||||
if expr in axes_1:
|
||||
return f'{prefix}r{expr}[newaxis, :, newaxis]'
|
||||
if expr in axes_2:
|
||||
return f'{prefix}r{expr}[newaxis, newaxis, :]'
|
||||
return expr

    def unpack_cc(tile, axes, prefix, remat):
        ret = ''
        axes = list(map(str, axes))
        for i, d in enumerate(reversed(axes)):
            if i == len(axes) - 1:
                break
            currs = ''.join(axes[: len(axes) - i])
            nexts = ''.join(axes[: len(axes) - (i + 1)])
            ty = '' if remat else 'int '
            sz = '' if remat or tile is None else f'[{tile}]'
            ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
            ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
        return ret
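To make the generated unpacking concrete, here is a hypothetical call mirroring how make_kernel uses this helper for a range fused over two axes; the innermost axis is peeled off with a division/modulo pair:

print(_einsum.unpack_cc('TM', ['a', 'b'], 'r', False))
# emits (modulo whitespace):
#   int ra[TM] = rab / dim_b;
#   int rb[TM] = rab % dim_b;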

    def strides_cc(name, expr):
        ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
        ret = dict(zip(expr, ret))
        return ret
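strides_cc just names the stride attached to each index and assumes the innermost dimension is contiguous; e.g. this hypothetical call:

assert _einsum.strides_cc('a', 'bmk') == {'b': 'stride_a_b', 'm': 'stride_a_m', 'k': '1'}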

    def make_kernel(name, dtype,
                    expr_a, expr_b, expr_c,
                    sparse_a, sparse_b, sparse_c,
                    axes_m, axes_n, axes_k, axes_b,
                    multipleof_a, multipleof_b, multipleof_c,
                    stride_a_last, stride_b_last, stride_c_last,
                    lut_mode_a, lut_mode_b,
                    delta_a, delta_b,
                    blocks):

        use_lut_a = True
        use_lut_b = True

        outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
        outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
        outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
        outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
        outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]

        src = ""

        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""

        src += f"""
__global__ void {name}(
      TYPE * A __noalias __readonly __aligned(16)
    , TYPE * B __noalias __readonly __aligned(16)
    , TYPE * C
    , int * locks
    , float alpha
    , int matmul_m, int matmul_n, int matmul_k __multipleof(16)
    , int div_m
    """
        for dim in [axes_m, axes_n, axes_k, axes_b]:
            for d in dim:
                src += f", int dim_{d}"
        src += "\n "
        for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
                                           ['a', 'b', 'c'],
                                           [multipleof_a, multipleof_b, multipleof_c],
                                           [sparse_a, sparse_b, sparse_c]):
            for d in range(len(dim) - 1):
                if sparse and dim[d] == sparse[0]:
                    src += f', int stride_{name}_block __multipleof({mult})'
                src += f", int stride_{name}_{d} __multipleof({mult})"
            src += "\n "
        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
            src += f", int rem_delta_a __multipleof({multipleof_a})"
        elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
            src += ", int* AD __noalias __readonly __aligned(16)"
        src += "\n "
        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_b_inner __multipleof({multipleof_b})"
            src += f", int rem_delta_b __multipleof({multipleof_b})"
        elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
            src += ", int* BD"
        src += "\n "
        if sparse_c:
            src += ", int* CD"
        if sparse_a or sparse_b:
            src += ", int width"
        src += """) {


    // program identifiers
    int pid_0 = get_program_id(0);
    int pid_1 = get_program_id(1);

"""
        if sparse_a:
            src += f"""
    int off_n = pid_0 / width;
    int off_header = pid_0 % width;
    int* header = AD + off_header * {2 + len(outer_sparse_a)};
    int* pdelta = AD + *(header + 0);
    matmul_k = *(header + 1);"""
            for i, d in enumerate(outer_sparse_a):
                src += f"""
    int off_{d} = *(header + {2 + i});"""
            src += f"""
    int inca = *(pdelta + 0);
    int incb = *(pdelta + 1);
    int off_{''.join(map(str, outer_dense_a))} = pid_1;
    """
            _einsum.unpack_cc(None, outer_dense_a, "off_", False)
        elif sparse_b:
            src += f"""
    int off_m = pid_0 / width;
    int off_header = pid_0 % width;
    int* header = BD + off_header * {2 + len(outer_sparse_b)};
    int* pdelta = BD + *(header + 0);
    matmul_k = *(header + 1);"""
            for i, d in enumerate(outer_sparse_b):
                src += f"""
    int off_{d} = *(header + {2 + i});"""
            src += f"""
    int incb = *(pdelta + 0);
    int inca = *(pdelta + 1);
    int off_{''.join(map(str, outer_dense_b))} = pid_1;
    """
            _einsum.unpack_cc(None, outer_dense_b, "off_", False)
        elif sparse_c:
            src += f"""
    // load LUT header
    int *header = CD + pid_0 * {len(sparse_c)};"""
            for i, d in enumerate(sparse_c):
                src += f"""
    int off_{d} = *(header + {i});"""
            src += f"""
    int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
        else:
            src += """
    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
    int off_mn = pid_0 / div_m;
    int off_n = off_mn % grid_n;
    int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
    int off_b = get_program_id(1);"""

        src += """
    #if TZ == 1
    int off_k = 0;
    #else
    // get reduction sub-group program id
    int pid_z = get_program_id(2);
    int grid_z = get_num_programs(2);
    int div_z = matmul_k / TZ;
    int rem_z = matmul_k % TZ;
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
    #endif
    int rem_k = matmul_k % TK;

    // create ranges
"""

        sparse = sparse_a + sparse_b + sparse_c
        for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
                                             ['TM', 'TN', 'TB', 'TK'],
                                             ['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
                                             [['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
            if not axes:
                continue
            currs = ''.join(map(str,axes))
            has_sparse_component = set(axes) & set(sparse)
            if has_sparse_component:
                src += f" int r{currs}[{tile}] = 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, f'r', False)
            else:
                src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, f'r', False)
            for pfx in prefixes:
                for d in axes:
                    is_dense_dim = d not in sparse
                    is_dense_storage = (pfx == 'a' and not sparse_a) or\
                                       (pfx == 'b' and not sparse_b) or\
                                       (pfx == 'c' and not sparse_c)
                    if not is_dense_dim and is_dense_storage:
                        src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
                    elif is_dense_dim and has_sparse_component:
                        src += f" int {pfx}r{d}[{tile}] = off_{d};\n"
                    else:
                        src += f" int {pfx}r{d}[{tile}] = r{d};\n"

        src += f"""
    // initialize pointers to A
    int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
        for i, sym in enumerate(expr_a):
            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
            stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
            src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""


        if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK] = {cast}AD + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""

        src += f"""

    // initialize pointers to B
    int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
        for i, sym in enumerate(expr_b):
            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
            stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
            src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""


        if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to B look-up table
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK] = BD + offbdelta;"""


        rk = 'r{}'.format(''.join(map(str,axes_k)))
        src += f"""

    // prefetch
    int prefetch_k = select(rem_k > 0, rem_k, TK);
    bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
    bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
    bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
    bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""

        if sparse_a:
            src += f"""
    // update pointers to look-up tables
    pdelta += 2;
    int incda = *(pdelta + 0);
    int incdb = *(pdelta + 1);
    pa += incda;
    pb += incdb;"""
        if sparse_b:
            src += f"""
    // update pointers to look-up tables
    pdelta += 2;
    int incdb = *(pdelta + 0);
    int incda = *(pdelta + 1);
    pa += incda;
    pb += incdb;"""

        if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += """
    pa += rem_delta_a;"""
        elif not sparse_a and not sparse_b:
            src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""

        if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
    pb += rem_delta_b;"""
        elif not sparse_a and not sparse_b:
            src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""

        src += f"""

    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;

        // load inputs
        checkk = k > TK;
        checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
        checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
        a = *?(checka)pa;
        b = *?(checkb)pb;

        // update pointers"""
        if sparse_a:
            src += """
        pdelta += 2;
        incda = *(pdelta + 0);
        incdb = *(pdelta + 1);
        pa += incda;
        pb += incdb;
        """
        elif sparse_b:
            src += """
        pdelta += 2;
        incdb = *(pdelta + 0);
        incda = *(pdelta + 1);
        pa += incda;
        pb += incdb;
        """
        else:
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                src += """
        pa += stride_a_inner;"""
            else:
                src += """
        pa += incda;
        padelta += TK;
        incda = (*padelta)[newaxis, :, newaxis];"""
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                src += """
        pb += stride_b_inner;"""
            else:
                src += """
        pb += (*pbdelta)[:, newaxis, newaxis];
        pbdelta += TK;"""

        src += f"""
    }}
    TYPE c[TM, TN, TB] = acc;

    // initialize pointers to C
    int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
        for i, sym in enumerate(expr_c):
            stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
            src += f"\n + ({ccode}) * {stride}"
        src += ';'

        src += """
    TYPE *pc[TM, TN, TB] = C + offc;

    // bounds-checking
    bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
    bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
    bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
                              checkcn[newaxis, :, newaxis];

    // write back
    #if TZ == 1
    *?(checkc)pc = c;
    #else
    int *plock = locks + pid_mn + pid_b * get_num_programs(0);
    int *pcount = plock + 1024*1024;
    // spin
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc)pc = c;
    else
      *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % (grid_z));
    atomic_xchg(plock, 0);
    #endif
}
"""
        # compilation options
        TM, TN, TB, TZ = [32], [32], 1, [1]
        TK = 16 if dtype==torch.float16 else 8
        defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
        for d, B in blocks.items():
            defines[f'BLOCK{d}'] = B
        # create kernel
        ret = triton.kernel(src, defines=defines)
        # set constant
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('AD', delta_a)
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('BD', delta_b)
        return ret

    ############################
    ## Look-up Table
    ############################

    class LUT_MODE(IntEnum):
        SCALAR = 1
        CONSTANT = 2
        DRAM = 3

    def lut_mode(delta):
        if delta.size == 0 or np.min(delta) == np.max(delta):
            return _einsum.LUT_MODE.SCALAR
        #if delta.size < 4096:
        #    return _einsum.LUT_MODE.CONSTANT
        return _einsum.LUT_MODE.DRAM
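In other words, an empty or constant delta table collapses to a single scalar increment, and everything else currently falls back to a DRAM look-up table (the constant-memory path is commented out). A quick hypothetical check:

import numpy as np
assert _einsum.lut_mode(np.array([8, 8, 8])) == _einsum.LUT_MODE.SCALAR
assert _einsum.lut_mode(np.array([8, 16, 8])) == _einsum.LUT_MODE.DRAM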

    def symbolic_delta(symbols, axes):
        rank = len(symbols)
        strides = [sp.symbols(f'stride{d}') for d in range(rank)]
        nexts = {s: sp.symbols(f'next{s}') for s in axes}
        delta = 0
        for i in range(rank):
            delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
        return delta

    def unpack_offset(k, axes, dims):
        ret = dict()
        for d in reversed(axes):
            ret[d] = k % dims[d]
            k = k // dims[d]
        return ret
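unpack_offset is the numeric counterpart of the code emitted by unpack_cc: it peels axes from innermost to outermost with modulo/division. A small hypothetical example:

# linear index 7 over axes ('i', 'j') with dims i=4, j=3  ->  i=2, j=1  (since 7 == 2*3 + 1)
assert _einsum.unpack_offset(7, ['i', 'j'], {'i': 4, 'j': 3}) == {'i': 2, 'j': 1}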

    def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
        # depth of reductions
        depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
        # outer dimension indices
        outer = torch.nonzero(depth, as_tuple=False)
        outer = [outer[:,i] for i in range(outer.shape[1])]
        # find offset of outer dimensions
        depth = depth.view(-1)
        offsets = torch.zeros_like(depth)
        offsets[1:] = torch.cumsum(depth[:-1], 0)
        # compute delta for b
        # TODO: support multiple sparse red indices
        col = next((i for i, d in enumerate(sparse) if d in axes), None)
        block = blocks[sparse[-1].upper()]
        div = block // step
        delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
        delta_b *= block
        delta_b = [delta_b + step*i for i in range(div)]
        delta_b = torch.stack(delta_b, dim=1)
        delta_b = delta_b.view(-1)
        # compute delta for a
        bstride = 1
        for d in sparse[::-1]:
            if d in axes:
                break
            bstride *= blocks[d.upper()]
        order = [d for d in sparse if d not in axes] +\
                [d for d in sparse if d in axes]
        idx = [sparse.index(d) for d in order]
        layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
        layout = layout.permute(*idx)
        delta_a = layout[layout > 0] - 1
        delta_a *= np.prod(list(blocks.values()))
        saved = delta_a[offsets]
        delta_a[1:] = delta_a[1:] - delta_a[:-1]
        delta_a = delta_a.view(-1, 1).repeat(1, div)
        delta_a[:, 1:] = step*bstride
        delta_a[:, 0] -= (div - 1)*step*bstride
        delta_a[offsets, 0] = saved
        delta_a = delta_a.view(-1)
        delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
        # form look-up table
        depth *= blocks[symbols[-1].upper()]
        offsets *= div
        header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
        nouter = 2 + len(outer)
        header[::nouter] = header[::nouter]*2 + header.shape[0]
        lut = torch.cat((header, delta)).int().cpu().numpy()
        return lut, nouter, _einsum.LUT_MODE.DRAM

    def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
        # symbolic pointer increments
        symbols = [sp.symbols(x) for x in symbols]
        delta = _einsum.symbolic_delta(symbols, axes)
        args = [f'stride{d}' for d in range(len(stride))]
        args += [f'{sk}' for sk in axes]
        args += [f'next{sk}' for sk in axes]
        fn = sp.lambdify(args, delta, 'numpy')
        if lut is None:
            # inner axes values
            inner = [dims[d] for d in axes]
            inner = np.prod(inner)
            rem = inner % step
            rem = rem if rem > 0 else step
            # k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
            # nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
            k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
            nextk = np.concatenate((k[:step] + rem, k[step:] + step))
        else:
            idx = (lut[:lut[0]:nouter] - lut[0])//2
            k = lut[lut[0]+1::2]
            k = np.insert(k, idx, 0)
            nextk = k[1:]
            k = k[:-1]
        # offsets
        off = _einsum.unpack_offset(k, axes, dims)
        nextoff = _einsum.unpack_offset(nextk, axes, dims)
        # evaluate deltas
        args = [s for s in stride]
        args += [off[sk] for sk in axes]
        args += [nextoff[sk] for sk in axes]
        delta = fn(*args)
        delta = np.maximum(delta, 0)
        if lut is not None:
            idx = idx[1:] + np.arange(idx.shape[0] - 1)
            delta = np.delete(delta, idx)
            lut[lut[0]+1::2] = delta
            return None, None
        return delta, _einsum.lut_mode(delta[step:-step])
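For the plain dense case the table degenerates to a constant: with a row-major A of shape (M, K), symbols 'mk' and reduction axis 'k', every interior pointer advance is stride_k * TK, so lut_mode reports SCALAR and the kernel later receives a single stride_a_inner value instead of a look-up table. A hedged sketch with illustrative sizes (assumes numpy and sympy are available):

delta, mode = _einsum.make_delta(axes=['k'], step=16, stride=[48, 1],
                                 dims={'m': 64, 'k': 48}, symbols='mk',
                                 sparse=[], layout=None)
assert mode == _einsum.LUT_MODE.SCALAR
assert delta[16] == 16   # steady-state increment along k, in elements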

    @staticmethod
    def make_sdd_lut(layout_c, sparse_c, blocks):
        nnz = torch.nonzero(layout_c, as_tuple=False)
        lut = nnz.reshape(-1).int().cuda()
        return lut

    ############################
    ## Compilation
    ############################

    class instance:

        locks = None
        kernel_cache = dict()

        def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
            # parse symbols
            expr_a, expr_bc = einsum.split(",")
            expr_b, expr_c = expr_bc.split("->")
            sym_a = expr_a.lower()
            sym_b = expr_b.lower()
            sym_c = expr_c.lower()
            sparse_a = [x.lower() for x in expr_a if x.isupper()]
            sparse_b = [x.lower() for x in expr_b if x.isupper()]
            sparse_c = [x.lower() for x in expr_c if x.isupper()]
            layout_a = layouts.get(expr_a)
            layout_b = layouts.get(expr_b)
            layout_c = layouts.get(expr_c)
            # parse axes
            axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
            axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
            axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
            axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
            axes = axes_b + axes_m + axes_n + axes_k
            # check block sizes
            for d in sparse_a + sparse_b + sparse_c:
                if d.upper() not in blocks:
                    raise ValueError(f'unspecified block size for dimension: {d.upper()}')
            # check layout is present
            if sparse_a and layout_a is None:
                raise ValueError('A is sparse but no layout provided')
            if sparse_b and layout_b is None:
                raise ValueError('B is sparse but no layout provided')
            if sparse_c and layout_c is None:
                raise ValueError('C is sparse but no layout provided')
            # check dimensions
            dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
            dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
            dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
            dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
            # TODO: could be cleaner
            read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
            for d in axes_b + axes_m + axes_n + axes_k:
                dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
                dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
                if d in axes_b and dim_a and dim_b and dim_a != dim_b:
                    raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
                if d in axes_k and dim_a and dim_b and dim_a != dim_b:
                    raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
            dims = dict()
            dims.update(dims_a)
            dims.update(dims_b)
            for i, d in enumerate(sparse_a):
                dims[d] = layout_a.shape[i] * blocks[d.upper()]
            for i, d in enumerate(sparse_b):
                dims[d] = layout_b.shape[i] * blocks[d.upper()]
            # allocate output
            shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
            if sparse_c:
                shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
            stride_c = [None] * len(shape_c)
            stride_c[-1] = 1
            for i in reversed(range(len(shape_c) - 1)):
                stride_c[i] = stride_c[i+1] * shape_c[i+1]
            # look-up tables
            TK = 16 if dtype == torch.float16 else 8
            if sparse_a and not sparse_b:
                delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
            if sparse_b and not sparse_a:
                delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
            if not sparse_a and not sparse_b:
                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
            if sparse_c:
                delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
            # hash for recompilation
            stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
            stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
            stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
            stride_a_last = stride_a[-1]
            stride_b_last = stride_b[-1]
            stride_c_last = stride_c[-1]
            name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
                   f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
                   f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
            # recompile if necessary
            cache = _einsum.instance.kernel_cache
            if name not in cache:
                cachesize = len(cache)
                cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
                                                  dtype,
                                                  sym_a, sym_b, sym_c,
                                                  sparse_a, sparse_b, sparse_c,
                                                  axes_m, axes_n, axes_k, axes_b,
                                                  stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                  stride_a_last, stride_b_last, stride_c_last,
                                                  lut_mode_a, lut_mode_b,
                                                  delta_a, delta_b,
                                                  blocks)
            self.kernel = cache[name]
            # Initialize locks
            if _einsum.instance.locks is None:
                _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
            # Kernel arguments
            dim_m = [dims[d] for d in axes_m]
            dim_n = [dims[d] for d in axes_n]
            dim_k = [dims[d] for d in axes_k]
            dim_b = [dims[d] for d in axes_b]
            M = reduce(mul, dim_m, 1)
            N = reduce(mul, dim_n, 1)
            K = reduce(mul, dim_k, 1)
            B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
            stride_a = list(stride_a[:-1])
            stride_b = list(stride_b[:-1])
            stride_c = list(stride_c[:-1])
            alpha = 1.
            div_m = 1
            self.args = [None, None, None]
            self.args += [_einsum.instance.locks]
            self.args += [alpha, M, N, K, div_m]
            self.args += dim_m
            self.args += dim_n
            self.args += dim_k
            self.args += dim_b
            self.args += stride_a
            self.args += stride_b
            self.args += stride_c
            # LUT for A
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_a[TK], delta_a[0]]
            elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_a).cuda()]
            # LUT for B
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_b[TK], delta_b[0]]
            elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_b).cuda()]
            # LUT for C
            if sparse_c:
                self.args += [delta_c]
            if sparse_a or sparse_b:
                width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
                self.args += [width]
            # Grid
            if sparse_a:
                self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
            elif sparse_b:
                self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
            elif sparse_c:
                width = int(layout_c.sum())
                self.grid = lambda opt: [width, B, opt.d('TZ')]
            else:
                self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
                                         triton.cdiv(N, opt.d('TN')),
                                         triton.cdiv(B, opt.d('TB')),
                                         opt.d('TZ')]
            # position of dynamic arguments
            self.pos_a = 0
            self.pos_b = 1
            self.pos_c = 2
            # save information on the operation
            self.expr_a = expr_a
            self.expr_b = expr_b
            self.expr_c = expr_c
            self.matmul_B = B
            self.matmul_M = M
            self.matmul_N = N
            self.matmul_K = K
            # output shape
            self.shape_c = shape_c


        def run(self, a, b):
            c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
            self.args[self.pos_a] = a
            self.args[self.pos_b] = b
            self.args[self.pos_c] = c
            self.kernel(*self.args, grid=self.grid)
            return c


    ############################
    ## Forward
    ############################

    instance_cache = dict()
    registry = dict()
    @staticmethod
    def forward(ctx, expr, a, b, layouts, blocks):
        # compile einsum instance
        cache = _einsum.instance_cache
        key = (expr, a.dtype,
               a.stride(), b.stride(),
               a.shape, b.shape)
        if key not in cache:
            cache[key] = _einsum.instance(expr, a.dtype,
                                          a.stride(), b.stride(),
                                          a.shape, b.shape,
                                          layouts, blocks)
        instance = cache[key]
        # run and mark c as modified in-place
        c = instance.run(a, b)
        # save information in context
        ctx.expr_a = instance.expr_a
        ctx.expr_b = instance.expr_b
        ctx.expr_c = instance.expr_c
        ctx.matmul_B = instance.matmul_B
        ctx.matmul_M = instance.matmul_M
        ctx.matmul_N = instance.matmul_N
        ctx.matmul_K = instance.matmul_K
        ctx.save_for_backward(a, b)
        return c


    ############################
    ## Backward
    ############################

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        expr_a = ctx.expr_a
        expr_b = ctx.expr_b
        expr_c = ctx.expr_c
        # gradient of first argument
        da = None
        if ctx.needs_input_grad[1]:
            da = torch.empty_like(a)
            einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
        # gradient of second argument
        db = None
        if ctx.needs_input_grad[2]:
            db = torch.empty_like(b)
            einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
        return None, da, db, None, None, None, None, None, None, None
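The subscript shuffling above is just the transpose rule for matmul gradients: for 'bmk,bkn->bmn' the recursive calls evaluate da with 'bmn,bkn->bmk' and db with 'bmk,bmn->bkn'. A quick check of those identities with torch.einsum, independent of the Triton kernels:

import torch
a = torch.randn(2, 3, 4, requires_grad=True)   # 'bmk'
b = torch.randn(2, 4, 5, requires_grad=True)   # 'bkn'
c = torch.einsum('bmk,bkn->bmn', a, b)
dy = torch.randn_like(c)
c.backward(dy)
assert torch.allclose(a.grad, torch.einsum('bmn,bkn->bmk', dy, b), atol=1e-5)
assert torch.allclose(b.grad, torch.einsum('bmk,bmn->bkn', a, dy), atol=1e-5)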


def einsum(expr, a, b, layouts = None, blocks = dict()):
    return _einsum.apply(expr, a, b, layouts, blocks)
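A minimal dense usage sketch of the function above (illustrative shapes; the tensors must live on the GPU since the instance allocates its locks and look-up tables with .cuda()). Upper-case subscripts mark block-sparse dimensions and additionally require a layout tensor plus a block size in blocks; for the dense path an empty layouts dict is passed explicitly because __init__ calls layouts.get(...) on it:

import torch
a = torch.randn(8, 64, 32, dtype=torch.float16, device='cuda')   # 'bmk'
b = torch.randn(8, 32, 48, dtype=torch.float16, device='cuda')   # 'bkn'
c = einsum('bmk,bkn->bmn', a, b, layouts=dict())                  # dense case: no blocks needed
assert c.shape == (8, 64, 48)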