[GENERAL] Merged v1.0alpha into master. Added features:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove over-engineering
- Debugging capabilities in the Python bindings (see the short sketch just below)
- Compilation errors for kernels that spill registers
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions
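The "debugging capabilities" bullet surfaces in the Python diff below as a kernel.asm(mode, device) method backed by a new get_fn_asm binding, with 'ptx' and 'sass' modes. A minimal usage sketch; the `kernel` variable here stands for any already-compiled triton.kernel instance (illustrative name, not an object defined in this diff):

    import torch
    import triton
    # `kernel` is assumed to be a compiled triton.kernel instance (illustrative).
    device = torch.device('cuda:0')
    print(kernel.asm('ptx', device))   # generated PTX for this device
    print(kernel.asm('sass', device))  # disassembled SASS for this device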

View File

@@ -1,56 +0,0 @@
import triton
import numpy as np
from enum import Enum

class MODE(Enum):
    TF = 1
    TORCH = 2

try:
    import tensorflow as tf
    mode = MODE.TF
except ModuleNotFoundError:
    pass
try:
    import torch
    mode = MODE.TORCH
except ModuleNotFoundError:
    pass

C, H, W, B = 32, 1, 1, 128
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)

if mode == MODE.TORCH:
    fw_x = torch.from_numpy(x).cuda()
    fw_gamma = torch.from_numpy(gamma).cuda()
    fw_beta = torch.from_numpy(beta).cuda()
    fw_dy = torch.from_numpy(dy).cuda()
    # register gradients
    fw_x.requires_grad_(True)
    fw_gamma.requires_grad_(True)
    fw_beta.requires_grad_(True)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_y.backward(fw_dy)

if mode == MODE.TF:
    fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
    fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
    fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
    fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
    fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
    # run
    sess = tf.InteractiveSession()
    feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
    sess.run(tf.global_variables_initializer())
    result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
    print(result)

View File

@@ -1,213 +0,0 @@
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
torch.manual_seed(0)
#torch.backends.cudnn.benchmark = True
configs = []
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
#(8192, 8192, 8192),
(64, 64, 64000),
(64, 64, 128000),
(256, 256, 64000),
(256, 256, 128000),
(1536, 16, 1536),
(1536, 32, 1536),
(1536, 64, 1536),
# (1536, 128, 1536),
# (4096, 16, 4096),
# (4096, 32, 4096),
# (4096, 64, 4096),
# (4096, 128, 4096),
# (127008, 768, 576)
]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a.t(), b)
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b.t())
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
# Relative attention
NTHSE = [
(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
# (64, 1024, 1, 64, 64),
(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
]
#for N, T, H, S, E in NTHSE:
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
# 1D Dense convolution
NCHKR = [
#(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
    configs += [([N, C, H],
                 [C, R, K],
                 [N, K, H - R + 1],
                 torch_fn,
                 'nc(h+r),crk->nkh',
                 dict(), None, None, None)]
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(128, 3, 32, 32, 64, 3, 3),
#(1, 1024, 32, 112, 112, 1024, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, G, H, W, K, R, S in NCHWKRS:
    stride = 2
    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
    P = (H - R + 1) // stride
    Q = (W - S + 1) // stride
    transform_a = lambda a: a.view(N, G, C // G, H, W)
    transform_b = lambda b: b.view(C // G, R, S, G, K // G)
    transform_c = lambda c: c.view(N, K, P, Q)
    configs += [([N, C, H, W],
                 [C // G, R, S, K],
                 [N, G, K // G, P, Q],
                 torch_fn,
                 'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
                 dict(), transform_a, transform_b, transform_c)]
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
    configs += [([N, C, D, H, W],
                 [C, T, R, S, K],
                 [N, K, D - T + 1, H - R + 1, W - R + 1],
                 torch_fn,
                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
                 dict(), None, None, None)]
# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
    'shift_cuda', ['kernels/shift_cuda.cpp',
                   'kernels/shift_cuda_kernel.cu'],
    extra_cflags=['-O3'])

class shift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        ctx.save_for_backward(shift)
        return shift_cuda.forward(x, shift)

    @staticmethod
    def backward(ctx, grad_output):
        shift, = ctx.saved_tensors
        grad_output = shift_cuda.backward(grad_output, shift)
        return grad_output, None
NCHWKRS = [
#(8, 64, 128, 128, 128, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2

    def shift_conv(a, b, **kwargs):
        shift_h, shift_w = kwargs['sh'], kwargs['sw']
        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
        shift_torch = torch.from_numpy(shift_torch).cuda()
        a = shift.apply(a, shift_torch)
        b = b.permute(1, 0)
        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
        return torch.nn.functional.conv2d(a, b)

    configs += [([N, C, H, W],
                 [C, K],
                 [N, K, H, W],
                 shift_conv,
                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                 {'sh': shift_h, 'sw': shift_w},
                 None, None, None)]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
        transform_a, transform_b, transform_c in configs:
    dtype = torch.cuda.FloatTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # reference output
    if torch_fn:
        rc = torch_fn(a, b, **arrays)
    else:
        rc = torch.einsum(expr, a, b)
    # triton output
    ta = a if transform_a is None else transform_a(a)
    tb = b if transform_b is None else transform_b(b)
    tc = torch.empty(c_shape, device=a.device)
    triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
    ctx = triton.ops._einsum.registry[tc]
    tc = tc if transform_c is None else transform_c(tc)
    # performance relative to equivalent matrix multiplication
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    cmp_eqbmm = True
    if cmp_eqbmm:
        a = torch.rand(B, M, K).type(dtype).cuda()
        b = torch.rand(B, K, N).type(dtype).cuda()
        c = torch.empty((B, M, N), device=a.device).cuda()
        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
        cmp_str = f'({ratio:4.2f})'
    else:
        cmp_str = ''
    # test and benchmark
    bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
    diff = (tc - rc).abs().max() / rc.abs().max()
    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')

View File

@@ -1,42 +0,0 @@
#include <torch/torch.h>
#include <vector>
// CUDA forward declarations
at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift);

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift);

// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor shift_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  CHECK_INPUT(input);
  CHECK_INPUT(shift);
  return shift_cuda_forward(input, shift);
}

at::Tensor shift_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  CHECK_INPUT(grad_input);
  CHECK_INPUT(shift);
  return shift_cuda_backward(grad_input, shift);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &shift_forward, "Shift forward (CUDA)");
  m.def("backward", &shift_backward, "Shift backward (CUDA)");
}

View File

@@ -1,111 +0,0 @@
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const int32_t* __restrict__ shift,
    scalar_t* __restrict__ output,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;
  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  const int32_t target_w = w + shift[2*c];
  const int32_t target_h = h + shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    output[target_idx] = input[idx];
  }
}

template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
    const scalar_t* __restrict__ grad_input,
    scalar_t* __restrict__ grad_output,
    const int32_t* __restrict__ shift,
    const int32_t B,
    const int32_t C,
    const int32_t W,
    const int32_t H) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*W*H;
  const int32_t CWH = C*W*H;
  const int32_t WH = W*H;
  const int32_t b = idx / CWH;
  const int32_t c = (idx - b*CWH) / WH;
  const int32_t w = (idx - b*CWH - c*WH) / W;
  const int32_t h = idx - b*CWH - c*WH - w*H;
  const int32_t target_w = w - shift[2*c];
  const int32_t target_h = h - shift[2*c + 1];
  const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    grad_output[target_idx] = grad_input[idx];
  }
}
} // namespace

at::Tensor shift_cuda_forward(
    const at::Tensor input,
    const at::Tensor shift) {
  const auto B = input.size(0);
  const auto C = input.size(1);
  const auto H = input.size(2);
  const auto W = input.size(3);
  const auto size = B*C*W*H;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto output = at::zeros_like(input);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
    shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        shift.data<int32_t>(),
        output.data<scalar_t>(),
        B,
        C,
        H,
        W);
  }));
  return output;
}

at::Tensor shift_cuda_backward(
    const at::Tensor grad_input,
    const at::Tensor shift) {
  const auto B = grad_input.size(0);
  const auto C = grad_input.size(1);
  const auto H = grad_input.size(2);
  const auto W = grad_input.size(3);
  const auto size = B*C*W*H;
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  auto grad_output = at::zeros_like(grad_input);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
    shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        grad_input.data<scalar_t>(),
        grad_output.data<scalar_t>(),
        shift.data<int32_t>(),
        B,
        C,
        H,
        W);
  }));
  return grad_output;
}
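As a sanity check on the indexing above, the kernels linearize (b, c, h, w) into a flat thread index and decode it back with integer division. A small Python sketch of the forward kernel's decode step (illustrative only, not part of the extension):

    # Mirror of the index arithmetic in shift_cuda_forward_kernel above.
    def decode(idx, C, H, W):
        CHW, HW = C * H * W, H * W
        b = idx // CHW
        c = (idx - b * CHW) // HW
        h = (idx - b * CHW - c * HW) // W
        w = idx - b * CHW - c * HW - h * W
        return b, c, h, w

    # e.g. with C=2, H=3, W=4: flat index 13 decodes to batch 0, channel 1, row 0, col 1
    assert decode(13, 2, 3, 4) == (0, 1, 0, 1)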

View File

@@ -1,109 +0,0 @@
import triton
import numpy
import torch
import itertools
torch.manual_seed(0)
numpy.random.seed(0)
def to_sparse(expr, data, layout, shape, block):
    # shape of result
    sparse = None
    shape_ret = []
    for i, d in enumerate(expr):
        if d.isupper() and sparse is None:
            sparse = i
            shape_ret.append(int(layout.sum()))
        if d.isupper():
            shape_ret.append(block[d])
        else:
            shape_ret.append(shape[i])
    # iterator
    steps = [block[d] if d.isupper() else 1 for d in expr]
    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
    # create result
    ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
    blockid = 0
    nzblockid = 0
    for curr in itertools.product(*it):
        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
            blockid = 0
            nzblockid = 0
        data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
        ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
        ret_slice.insert(sparse, nzblockid)
        if int(layout.view(-1)[blockid]):
            ret[ret_slice] = data[data_slice]
            nzblockid += 1
        blockid += 1
    return ret
def to_dense(expr, data, layout, shape, block):
    sparse = None
    for i, d in enumerate(expr):
        if d.isupper() and sparse is None:
            sparse = i
    ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
    steps = [block[d] if d.isupper() else 1 for d in expr]
    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
    blockid = 0
    nzblockid = 0
    for curr in itertools.product(*it):
        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
            blockid = 0
            nzblockid = 0
        ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
        data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
        data_slice.insert(sparse, nzblockid)
        if int(layout.view(-1)[blockid]):
            ret[ret_slice] = data[data_slice]
            nzblockid += 1
        blockid += 1
    return ret
def test_expr(expr, shape, blocks):
    # decompose expr
    expr_a, expr_bc = expr.split(",")
    expr_b, expr_c = expr_bc.split("->")
    # check which arguments are sparse
    sparse_a = any(x.isupper() for x in expr_a)
    sparse_b = any(x.isupper() for x in expr_b)
    sparse_c = any(x.isupper() for x in expr_c)
    # allocate data
    shape_a = [shape[d.lower()] for d in expr_a]
    shape_b = [shape[d.lower()] for d in expr_b]
    shape_c = [shape[d.lower()] for d in expr_c]
    ref_a = torch.rand(*shape_a, device='cuda')
    ref_b = torch.rand(*shape_b, device='cuda')
    ref_c = torch.zeros(*shape_c, device='cuda')
    # layouts
    layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
    layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
    layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
    layout_a = torch.randint(0, 2, layout_a, device='cuda')
    layout_b = torch.randint(0, 2, layout_b, device='cuda')
    layout_c = torch.randint(0, 2, layout_c, device='cuda')
    # triton computation
    triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
    triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
    layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
    triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
    torch.cuda.synchronize()
    # reference computation
    ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
    ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
    ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
    if sparse_c:
        ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
    torch.cuda.synchronize()
    print((ref_c - triton_c).abs().max())
# shape characteristics
test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})
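In these expressions an uppercase index marks a block-sparse dimension whose block size comes from the last argument, and the corresponding layout tensor holds one 0/1 entry per block. A small sketch of the layout shape implied by the first call above (it mirrors the layout_a computation inside test_expr):

    shape = {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}
    blocks = {'H': 1, 'M': 32, 'K': 32}
    expr_a = 'bHMK'
    layout_shape = [shape[d.lower()] // blocks[d] for d in expr_a if d.isupper()]
    print(layout_shape)  # [2, 8, 8]: one 0/1 entry per 1x32x32 block of operand A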

View File

@@ -171,7 +171,7 @@ class _conv(torch.autograd.Function):
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = triton.empty([Z, CO, P, Q], dtype=dtype)
c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))]

View File

@@ -3,6 +3,9 @@ import triton
class _dot(torch.autograd.Function):
src = """
#define STM 4
#define STN 4
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
@@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = M / TM;
int gridn = N / TN;
int stgridm = (gridm + STM - 1) / STM;
int stgridn = (gridn + STN - 1) / STN;
int stid = pid / (STM * STN);
int laneid = pid % (STM * STN);
int stm = stid / stgridn;
int stn = stid % stgridn;
int lanem = laneid / STN;
int lanen = laneid % STN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
@@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
}
@@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = ridx * TM + 0 ... TM;
int rxn[TN] = ridy * TN + 0 ... TN;
int rxm[TM] = pidm * TM + 0 ... TM;
int rxn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
@@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [128],
'TN' : [128],
'TK' : [16],
'TK' : [32],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
M, K = a.shape
K, N = b.shape
c = torch.empty([M,N], dtype=dtype, device=a.device)
print(kernel.asm('sass', c.device))
print(kernel.asm('ptx', c.device))
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
triton.cdiv(N, opt.d('TN'))]
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
@@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half()
zc = torch.matmul(a,b)
zc_ = dot(a,b)
print(torch.allclose(zc, zc_))
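The prologue change in this hunk replaces the plain row-major mapping of program ids to (tile-row, tile-column) with an STM-by-STN "super-grouping": consecutive program ids now cover a compact block of C tiles and therefore reuse the same row panels of A and column panels of B from cache. A small Python sketch of the remapping; names follow the kernel above, but the function itself is only illustrative:

    def remap(pid, grid_n, STM=4, STN=4):
        # number of super-groups along the N dimension
        stgridn = (grid_n + STN - 1) // STN
        stid, laneid = pid // (STM * STN), pid % (STM * STN)
        stm, stn = stid // stgridn, stid % stgridn
        lanem, lanen = laneid // STN, laneid % STN
        return stm * STM + lanem, stn * STN + lanen

    # with an 8x8 grid of tiles, programs 0..15 now cover a 4x4 block of tiles
    print([remap(pid, grid_n=8) for pid in range(16)])

This is also why the launch grid in the same diff collapses to a single dimension, triton.cdiv(M, TM) * triton.cdiv(N, TN).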

View File

@@ -111,7 +111,7 @@ setup(
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops'],
packages=['triton', 'triton/_C'],
install_requires=['numpy', 'torch', 'sympy'],
package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')],

View File

@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {
void register_fn(const map_key_t& key,
const std::string& src,
const rt::function::options_space_t& opt) {
const rt::options_space_t& opt) {
if(id_fn_map.find(key) == id_fn_map.end())
id_fn_map[key].reset(new rt::function(src, opt, ""));
}
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
}
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
return id_fn_map[key]->ptx(&device, opt);
std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
triton::driver::cu_device device(key.second, false);
return id_fn_map[key]->get_asm(mode, &device, opt);
}
void cleanup() {
@@ -63,7 +63,7 @@ size_t make_op_id() {
/* Function signature */
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
const runtime::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
}
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
const runtime::function::options_space_t& opt) {
const runtime::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
return ret;
}
typedef triton::runtime::function::options_t options_t;
typedef triton::runtime::function::options_space_t options_space_t;
typedef triton::runtime::options_t options_t;
typedef triton::runtime::options_space_t options_space_t;
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
@@ -112,6 +112,10 @@ PYBIND11_MODULE(libtriton, m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("get_fn_signature", &get_fn_signature);
m.def("get_fn_ptx", &get_fn_ptx);
m.def("get_fn_asm", &get_fn_asm);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);

View File

@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
}
}
torch::Tensor raw_like(torch::Tensor x){
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
auto shape = x.sizes();
CUdeviceptr data;
triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
ret.copy_(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::raw_like", &raw_like)
.op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);

View File

@@ -1,7 +1,4 @@
from .kernel import *
import triton.ops
#import triton.nn
# clean-up libtriton resources
import atexit

View File

@@ -68,8 +68,17 @@ class kernel:
size = sum([sizes[x] for x in arg_types])
self.tys = ''.join([codes[x] for x in arg_types])
def ptx(self, device, **kwargs):
def asm(self, mode, device, **kwargs):
dev_id = device.index
# assembly mode
supported = {
'ptx': libtriton.asm_mode.ptx,
'sass': libtriton.asm_mode.sass,
}
if mode not in supported:
raise('ASM mode must be in ', supported.keys())
mode = supported[mode]
# disambiguates #defines
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
def _single_value_or_err(x, key):
if isinstance(x, list) and len(x) == 1:
@@ -86,15 +95,18 @@ class kernel:
opt = libtriton.options()
opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
opt.defines = defines
return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
# run
return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)
def __call__(self, *args, **kwargs):
if 'TRITON_DEBUG_MODE' in os.environ:
_args = args
args = [x for x in args]
args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
for i in range(len(args)):
if isinstance(args[i], torch.Tensor):
args[i] = torch.ops.triton.raw_like(args[i])
args[i] = torch.ops.triton.cuda_empty_like(args[i])
args[i].copy_(_args[i])
torch.cuda.synchronize()
for x in args:
if isinstance(x, torch.Tensor):
device = x.device.index
@@ -116,6 +128,8 @@ class kernel:
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
if 'TRITON_DEBUG_MODE' in os.environ:
torch.cuda.synchronize()
for i in range(len(args)):
if isinstance(args[i], torch.Tensor):
_args[i].copy_(args[i])
_args[i].copy_(args[i].clone())
args = _args
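The TRITON_DEBUG_MODE branch above clones every tensor argument into a freshly allocated CUDA buffer (torch.ops.triton.cuda_empty_like), synchronizes around the launch, and copies the results back, which helps surface out-of-bounds writes that would otherwise silently corrupt neighboring tensors. A sketch of how to enable it from Python (the surrounding script is hypothetical):

    import os
    # route all subsequent Triton launches through the shadow-copy debug path;
    # equivalently: TRITON_DEBUG_MODE=1 python my_script.py
    os.environ['TRITON_DEBUG_MODE'] = '1'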

View File

@@ -1,2 +0,0 @@
from .conv import replace_conv2d
from .attention import replace_mah

View File

@@ -1,312 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
def bmm(x, w, mask = None):
    b, m, k = x.size()
    b, k, n = w.size()
    out = torch.empty([b, m, n], device=x.device)
    triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False)
    return out
def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor
value, # type: Tensor
embed_dim_to_check, # type: int
num_heads, # type: int
in_proj_weight, # type: Tensor
in_proj_bias, # type: Tensor
bias_k, # type: Optional[Tensor]
bias_v, # type: Optional[Tensor]
add_zero_attn, # type: bool
dropout_p, # type: float
out_proj_weight, # type: Tensor
out_proj_bias, # type: Tensor
training=True, # type: bool
key_padding_mask=None, # type: Optional[Tensor]
need_weights=True, # type: bool
attn_mask=None, # type: Optional[Tensor]
use_separate_proj_weight=False, # type: bool
q_proj_weight=None, # type: Optional[Tensor]
k_proj_weight=None, # type: Optional[Tensor]
v_proj_weight=None, # type: Optional[Tensor]
static_k=None, # type: Optional[Tensor]
static_v=None, # type: Optional[Tensor]
acc_bitmask=None
):
# type: (...) -> Tuple[Tensor, Optional[Tensor]]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions. This is an additive mask
(i.e. the values will be added to the attention layer).
use_separate_proj_weight: the function accept the proj. weights for query, key,
and value in differnt forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size() == value.size()
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if not use_separate_proj_weight:
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
if key is None:
assert value is None
k = None
v = None
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = F.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = F.linear(value, _w, _b)
else:
q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
len1, len2 = q_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == query.size(-1)
k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
len1, len2 = k_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == key.size(-1)
v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
len1, len2 = v_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == value.size(-1)
if in_proj_bias is not None:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
else:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
q = q * scaling
if bias_k is not None and bias_v is not None:
if static_k is None and static_v is None:
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask,
torch.zeros((attn_mask.size(0), 1),
dtype=attn_mask.dtype,
device=attn_mask.device)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
dtype=key_padding_mask.dtype,
device=key_padding_mask.device)], dim=1)
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert bias_k is None
assert bias_v is None
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
dtype=attn_mask.dtype,
device=attn_mask.device)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
dtype=key_padding_mask.dtype,
device=key_padding_mask.device)], dim=1)
attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask)
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = F.softmax(
attn_output_weights, dim=-1)
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = bmm(attn_output_weights, v, mask=acc_bitmask)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
class MultiheadAttention(nn.modules.activation.MultiheadAttention):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None):
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim)
self.acc_bitmask = acc_bitmask
def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
if not self._qkv_same_embed_dim:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
acc_bitmask=self.acc_bitmask)
else:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
acc_bitmask=self.acc_bitmask)
def replace_mah(model, mask = None):
for child_name, child in model.named_children():
if isinstance(child, nn.modules.activation.MultiheadAttention):
add_bias_kv = child.bias_k is not None
device = child.in_proj_weight.device
mah = MultiheadAttention(child.embed_dim, child.num_heads,
dropout=child.dropout, add_bias_kv=add_bias_kv,
add_zero_attn=child.add_zero_attn, kdim=child.kdim,
vdim=child.vdim, acc_bitmask=mask).to(device)
for yparam, xparam in zip(mah.parameters(), child.parameters()):
yparam.data.copy_(xparam.data)
setattr(model, child_name, mah)
else:
replace_mah(child, mask)

View File

@@ -1,166 +0,0 @@
import triton
import torch.nn as nn
import torch
import torch.nn.functional as F
class _conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias,
stride, padding, dilation, groups,
acc_bitmask):
assert dilation == (1, 1)
assert groups == 1
assert bias == None
pad_h, pad_w = padding
stride_h, stride_w = stride
n, c, h, w = x.size()
k, c, r, s = weight.size()
# allocate output
p = (h + 2*padding[0] - r)//stride[0] + 1
q = (w + 2*padding[1] - s)//stride[1] + 1
output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device)
# padding
if pad_h or pad_w:
x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h])
# convolution
triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
x, weight, mask=acc_bitmask,
output=output,
values = {'pad_h': pad_h,
'stride_h': stride_h,
'pad_w': pad_w,
'stride_w': stride_w})
# prepare backprop
ctx.save_for_backward(x, weight)
ctx.stride = stride
ctx.padding = padding
ctx.acc_bitmask = acc_bitmask
# return
return output
@staticmethod
def backward(ctx, dy):
# retrieve contextual information
x, weight = ctx.saved_tensors
stride = ctx.stride
padding = ctx.padding
acc_bitmask = ctx.acc_bitmask
# gradient of the input
dx = None
if ctx.needs_input_grad[0]:
# dy must be padded
n, k, p, q = dy.size()
n, c, h, w = x.size()
k, c, r, s = weight.size()
dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
# have to be careful here
# the gradient of strided conv is a conv over a sparse image
# which can be decomposed as a set of smaller convs
dx = torch.empty_like(x)
for offh in range(stride[0]):
for offw in range(stride[1]):
poffh = (offh + padding[0]) % stride[0]
poffw = (offw + padding[1]) % stride[1]
pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
if poffh >= r or poffw >= s:
dx[:, :, offh::stride[0], offw::stride[1]] = 0
else:
triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
dypad[:, :, :, :],
weight[:, :, poffh::stride[0], poffw::stride[1]],
output = dx[:, :, offh::stride[0], offw::stride[1]],
mask = acc_bitmask,
values = {'pad_h': pad_h,
'pad_w': pad_w})
# gradient for the weight
dw = None
if ctx.needs_input_grad[1]:
dw = torch.empty_like(weight)
triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
x, dy, output = dw, mask = acc_bitmask)
#print('dw: ', dw.view(-1)[0])
return dx, dw, None, None, None, None, None, None
conv2d = _conv2d.apply
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
acc_bitmask = None):
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, padding_mode)
self.acc_bitmask = acc_bitmask
def forward(self, input):
#if self.kernel_size[0] == 3 and self.stride[0] != 1:
#print(self.padding, self.stride, input.size(), self.weight.size())
# return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.acc_bitmask)
def replace_conv2d(model, acc_bitmask = None):
for child_name, child in model.named_children():
if isinstance(child, nn.Conv2d):
conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size,
child.stride, child.padding, child.dilation, child.groups,
child.bias, child.padding_mode,
acc_bitmask=acc_bitmask)
for yparam, xparam in zip(conv2d.parameters(), child.parameters()):
yparam.data.copy_(xparam.data)
setattr(model, child_name, conv2d)
else:
replace_conv2d(child, acc_bitmask)
# initialize input
#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
if __name__ == '__main__':
N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
#N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
#N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
pad, stride = 0, 1
torch.manual_seed(0)
x = torch.randn((N, C, H, W)).cuda()
x.requires_grad_(True)
#x.data[:] = 1
# initialize layers
torch.manual_seed(0)
rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
torch.manual_seed(0)
tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
#rconv2d.weight.data[:] = 1
#tconv2d.weight.data[:] = 1
ry = rconv2d(x)
ty = tconv2d(x)
# reference
dy = torch.randn(ry.size()).cuda()
#dy.data[:] = 1
ry.backward(dy)
rdx = x.grad.clone()
rdw = rconv2d.weight.grad.clone()
x.grad.zero_()
# triton
ty.backward(dy)
tdx = x.grad.clone()
tdw = tconv2d.weight.grad.clone()
x.grad.zero_()
# print error
diff = lambda x, y: (x - y).abs().max()
print(diff(ry, ty))
print(diff(rdx, tdx))
print(diff(rdw, tdw))
#print((rdx - tdx).abs())
#print((rdx[0,0,:,:] - tdx[0,0,:,:]))
#print(rdx[0,0,:,:])
#print(tdx[0,0,:,:])

View File

@@ -1,13 +0,0 @@
import torch
import triton
def linear(x, w, bias = None):
    print(x.size(), w.size())
    m, k = x.size()
    k, n = w.size()
    out = torch.empty([m, n], device=x.device)
    triton.ops.einsum('mk,nk->mn', x, w, bias)
    if bias is not None:
        out += bias
    return out

View File

@@ -1,2 +0,0 @@
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm

View File

@@ -1,136 +0,0 @@
import triton
import torch
import math
class _batchnorm(torch.autograd.Function):
fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
float *X, float *G, float *B,
int N, float eps) {
// pointers
int c = get_program_id(1);
int rm[TM] = 0 ... TM;
float *px[TM] = X + rm + c*N;
float* py[TM] = Y + rm + c*N;
// compute mean
float accm[TM] = 0;
for(int i = 0; i < N; i = i + TM)
accm = accm + *(px + i);
float mean = (float)accm[+] / N;
*(M + c) = mean;
// compute variance
float accv[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
x = x - mean;
accv = accv + x*x;
}
float var = (float)accv[+] / N;
*(V + c) = var;
// Normalize batch
float gamma = *(G + c);
float beta = *(B + c);
float rstdg = 1 / sqrtf(var + eps) * gamma;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float y[TM] = (x - mean)*rstdg + beta;
*(py + i) = y;
}
}
"""
bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
float *DY, float *X, float *G,
float *M, float *V,
int N, float epsilon) {
// pointers
int c = get_program_id(1);
int rx[TM] = 0 ... TM;
int offset = c*N;
float* px[TM] = X + rx + offset;
float* pdy[TM] = DY + rx + offset;
float* pdx[TM] = DX + rx + offset;
// fetch statistics
float gamma = *(G + c);
float mean = *(M + c);
float var = *(V + c);
float rstd = 1 / sqrtf(var + epsilon);
// compute dgamma and dbeta
float acc_dg[TM] = 0;
float acc_db[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
acc_dg += dy*(x - mean)*rstd;
acc_db += dy;
}
float dg = acc_dg[+];
float db = acc_db[+];
*(DG + c) = dg;
*(DB + c) = db;
// compute dx
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
float xhat[TM] = (x - mean) * rstd;
float xtmp[TM] = (xhat * dg + db) / N;
float dx[TM] = (dy - xtmp) * rstd * gamma;
*(pdx + i) = dx;
}
}
"""
fwd_kernel = None
bwd_kernel = None
@staticmethod
def forward(ctx, x, gamma, beta, eps):
# lazy compilation of kernel
if _batchnorm.fwd_kernel is None:
_batchnorm.fwd_kernel = triton.kernel(fwd_src, defines = {'TM': 128})
# shapes
shape = triton.shape(x)
dtype = x.dtype
# allocate outputs
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
y = triton.empty(shape, dtype=dtype)
mean = triton.empty([C], dtype=dtype)
var = triton.empty([C], dtype=dtype)
# execute kernels
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
grid = lambda opt: [1, C])
# save
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
# lazy compilation of kernel
if _batchnorm.bwd_kernel is None:
_batchnorm.bwd_kernel = triton.kernel(bwd_src, defines = {'TN': 128})
# retrieve info
x, gamma, beta, mean, var = ctx.saved_tensors
eps = ctx.eps
# allocate result
dx = triton.empty(triton.shape(x), dtype=x.dtype)
dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
# execute
C, H, W, B = triton.shape(x)
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
H*W*B, eps,
grid = lambda opt: [1, C])
return dx, dgamma, dbeta, None
batchnorm = _batchnorm.apply

View File

@@ -1,794 +0,0 @@
from math import ceil, log2
from enum import IntEnum
from functools import reduce
from operator import mul
from collections import OrderedDict
from collections import namedtuple
import re
import string
import triton
import torch
# numpy -- ideally removed in a future release
import numpy as np
# sympy -- ideally removed in a future release
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.printing.ccode import C89CodePrinter
class _einsum(torch.autograd.Function):
#############################
## Triton-C code generation
#############################
def print_cc(expr, axes_0, axes_1, axes_2, prefix):
if expr in axes_0:
return f'{prefix}r{expr}[:, newaxis, newaxis]'
if expr in axes_1:
return f'{prefix}r{expr}[newaxis, :, newaxis]'
if expr in axes_2:
return f'{prefix}r{expr}[newaxis, newaxis, :]'
return expr
def unpack_cc(tile, axes, prefix, remat):
ret = ''
axes = list(map(str, axes))
for i, d in enumerate(reversed(axes)):
if i == len(axes) - 1:
break
currs = ''.join(axes[: len(axes) - i])
nexts = ''.join(axes[: len(axes) - (i + 1)])
ty = '' if remat else 'int '
sz = '' if remat or tile is None else f'[{tile}]'
ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
return ret
def strides_cc(name, expr):
ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
ret = dict(zip(expr, ret))
return ret
def make_kernel(name, dtype,
expr_a, expr_b, expr_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
blocks):
use_lut_a = True
use_lut_b = True
outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]
src = ""
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""
src += f"""
__global__ void {name}(
TYPE * A __noalias __readonly __aligned(16)
, TYPE * B __noalias __readonly __aligned(16)
, TYPE * C
, int * locks
, float alpha
, int matmul_m, int matmul_n, int matmul_k __multipleof(16)
, int div_m
"""
for dim in [axes_m, axes_n, axes_k, axes_b]:
for d in dim:
src += f", int dim_{d}"
src += "\n "
for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
['a', 'b', 'c'],
[multipleof_a, multipleof_b, multipleof_c],
[sparse_a, sparse_b, sparse_c]):
for d in range(len(dim) - 1):
if sparse and dim[d] == sparse[0]:
src += f', int stride_{name}_block __multipleof({mult})'
src += f", int stride_{name}_{d} __multipleof({mult})"
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
src += f", int rem_delta_a __multipleof({multipleof_a})"
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
src += f", int rem_delta_b __multipleof({multipleof_b})"
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
src += "\n "
if sparse_c:
src += ", int* CD"
if sparse_a or sparse_b:
src += ", int width"
src += """) {
// program identifiers
int pid_0 = get_program_id(0);
int pid_1 = get_program_id(1);
"""
if sparse_a:
src += f"""
int off_n = pid_0 / width;
int off_header = pid_0 % width;
int* header = AD + off_header * {2 + len(outer_sparse_a)};
int* pdelta = AD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_a):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int inca = *(pdelta + 0);
int incb = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_a))} = pid_1;
"""
_einsum.unpack_cc(None, outer_dense_a, "off_", False)
elif sparse_b:
src += f"""
int off_m = pid_0 / width;
int off_header = pid_0 % width;
int* header = BD + off_header * {2 + len(outer_sparse_b)};
int* pdelta = BD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_b):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int incb = *(pdelta + 0);
int inca = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_b))} = pid_1;
"""
_einsum.unpack_cc(None, outer_dense_b, "off_", False)
elif sparse_c:
src += f"""
// load LUT header
int *header = CD + pid_0 * {len(sparse_c)};"""
for i, d in enumerate(sparse_c):
src += f"""
int off_{d} = *(header + {i});"""
src += f"""
int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
else:
src += """
// re-order outer program ids
int grid_m = (matmul_m + TM - 1) / TM;
int grid_n = (matmul_n + TN - 1) / TN;
int off_mn = pid_0 / div_m;
int off_n = off_mn % grid_n;
int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
int off_b = get_program_id(1);"""
src += """
#if TZ == 1
int off_k = 0;
#else
// get reduction sub-group program id
int pid_z = get_program_id(2);
int grid_z = get_num_programs(2);
int div_z = matmul_k / TZ;
int rem_z = matmul_k % TZ;
int off_k = pid_z * div_z;
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
int rem_k = matmul_k % TK;
// create ranges
"""
sparse = sparse_a + sparse_b + sparse_c
for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
[['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
if not axes:
continue
currs = ''.join(map(str,axes))
has_sparse_component = set(axes) & set(sparse)
if has_sparse_component:
src += f" int r{currs}[{tile}] = 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
else:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
for pfx in prefixes:
for d in axes:
is_dense_dim = d not in sparse
is_dense_storage = (pfx == 'a' and not sparse_a) or\
(pfx == 'b' and not sparse_b) or\
(pfx == 'c' and not sparse_c)
if not is_dense_dim and is_dense_storage:
src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
elif is_dense_dim and has_sparse_component:
src += f" int {pfx}r{d}[{tile}] = off_{d};\n"
else:
src += f" int {pfx}r{d}[{tile}] = r{d};\n"
src += f"""
// initialize pointers to A
int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pa[TM, TK, TB] = A + offa;"""
if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
int offadelta[TK] = off_k + 0 ... TK;
int {spec} *padelta[TK] = {cast}AD + offadelta;
int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
src += f"""
// initialize pointers to B
int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pb[TK, TN, TB] = B + offb;"""
if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to B look-up table
int offbdelta[TK] = off_k + 0 ... TK;
int *pbdelta[TK] = BD + offbdelta;"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
src += f"""
// prefetch
int prefetch_k = select(rem_k > 0, rem_k, TK);
bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
if sparse_a:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incda = *(pdelta + 0);
int incdb = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if sparse_b:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incdb = *(pdelta + 0);
int incda = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += rem_delta_a;"""
elif not sparse_a and not sparse_b:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += rem_delta_b;"""
elif not sparse_a and not sparse_b:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;
// load inputs
checkk = k > TK;
checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;
// update pointers"""
if sparse_a:
src += """
pdelta += 2;
incda = *(pdelta + 0);
incdb = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
elif sparse_b:
src += """
pdelta += 2;
incdb = *(pdelta + 0);
incda = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
else:
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += stride_b_inner;"""
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
}}
TYPE c[TM, TN, TB] = acc;
// initialize pointers to C
int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
src += f"\n + ({ccode}) * {stride}"
src += ';'
src += """
TYPE *pc[TM, TN, TB] = C + offc;
// bounds-checking
bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
checkcn[newaxis, :, newaxis];
// write back
#if TZ == 1
*?(checkc)pc = c;
#else
int *plock = locks + pid_mn + pid_b * get_num_programs(0);
int *pcount = plock + 1024*1024;
// spin
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc)pc = c;
else
*?(checkc)pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % (grid_z));
atomic_xchg(plock, 0);
#endif
}
"""
# compilation options
TM, TN, TB, TZ = [32], [32], 1, [1]
TK = 16 if dtype==torch.float16 else 8
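# NB: this TK must match the dtype-dependent step used to build the look-up tables
# (the same 16-for-fp16 / 8-otherwise value is passed to make_delta / make_dsd_delta).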
defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
for d, B in blocks.items():
defines[f'BLOCK{d}'] = B
# create kernel
ret = triton.kernel(src, defines=defines)
# set constant
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('BD', delta_b)
return ret
############################
## Look-up Table
############################
class LUT_MODE(IntEnum):
SCALAR = 1
CONSTANT = 2
DRAM = 3
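# How the three modes are consumed (summary of the code in this file):
# - SCALAR:   every pointer increment along the reduction is identical, so the kernel
#             only receives two scalar increments (inner step and remainder step).
# - CONSTANT: the delta table is small enough to live in __constant__ memory and is
#             bound with kernel.set_constant (currently disabled in lut_mode below).
# - DRAM:     the delta table is uploaded as a regular tensor and dereferenced through
#             the padelta / pbdelta pointers inside the generated kernel.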
def lut_mode(delta):
if delta.size == 0 or np.min(delta) == np.max(delta):
return _einsum.LUT_MODE.SCALAR
#if delta.size < 4096:
# return _einsum.LUT_MODE.CONSTANT
return _einsum.LUT_MODE.DRAM
def symbolic_delta(symbols, axes):
rank = len(symbols)
strides = [sp.symbols(f'stride{d}') for d in range(rank)]
nexts = {s: sp.symbols(f'next{s}') for s in axes}
delta = 0
for i in range(rank):
delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
return delta
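# Example (sketch): for symbols (i, k) with axes = ['k'], the sum above reduces to
# stride1 * (nextk - k), i.e. the pointer increment needed to move the reduction
# offset from k to nextk.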
def unpack_offset(k, axes, dims):
ret = dict()
for d in reversed(axes):
ret[d] = k % dims[d]
k = k // dims[d]
return ret
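# Example (sketch): with axes = ['i', 'j'] and dims = {'i': 3, 'j': 4}, the linear
# index k = 7 unpacks to {'j': 3, 'i': 1}, i.e. the multi-index whose row-major
# linearization 1*4 + 3 gives back 7.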
def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
# depth of reductions
depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
# outer dimension indices
outer = torch.nonzero(depth, as_tuple=False)
outer = [outer[:,i] for i in range(outer.shape[1])]
# find offset of outer dimensions
depth = depth.view(-1)
offsets = torch.zeros_like(depth)
offsets[1:] = torch.cumsum(depth[:-1], 0)
# compute delta for b
# TODO: support multiple sparse red indices
col = next((i for i, d in enumerate(sparse) if d in axes), None)
block = blocks[sparse[-1].upper()]
div = block // step
delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
delta_b *= block
delta_b = [delta_b + step*i for i in range(div)]
delta_b = torch.stack(delta_b, dim=1)
delta_b = delta_b.view(-1)
# compute delta for a
bstride = 1
for d in sparse[::-1]:
if d in axes:
break
bstride *= blocks[d.upper()]
order = [d for d in sparse if d not in axes] +\
[d for d in sparse if d in axes]
idx = [sparse.index(d) for d in order]
layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
layout = layout.permute(*idx)
delta_a = layout[layout > 0] - 1
delta_a *= np.prod(list(blocks.values()))
saved = delta_a[offsets]
delta_a[1:] = delta_a[1:] - delta_a[:-1]
delta_a = delta_a.view(-1, 1).repeat(1, div)
delta_a[:, 1:] = step*bstride
delta_a[:, 0] -= (div - 1)*step*bstride
delta_a[offsets, 0] = saved
delta_a = delta_a.view(-1)
delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
# form look-up table
depth *= blocks[symbols[-1].upper()]
offsets *= div
header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
nouter = 2 + len(outer)
header[::nouter] = header[::nouter]*2 + header.shape[0]
lut = torch.cat((header, delta)).int().cpu().numpy()
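# DSD LUT layout (sketch): [header | interleaved (sparse, dense) delta pairs].
# Each nonzero outer block contributes `nouter` header entries: an offset into the
# delta section (doubled and shifted by the header length so it indexes the
# concatenated array directly), the length of its reduction, and its outer coordinates.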
return lut, nouter, _einsum.LUT_MODE.DRAM
def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
# symbolic pointer increments
symbols = [sp.symbols(x) for x in symbols]
delta = _einsum.symbolic_delta(symbols, axes)
args = [f'stride{d}' for d in range(len(stride))]
args += [f'{sk}' for sk in axes]
args += [f'next{sk}' for sk in axes]
fn = sp.lambdify(args, delta, 'numpy')
if lut is None:
# inner axes values
inner = [dims[d] for d in axes]
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
# nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
nextk = np.concatenate((k[:step] + rem, k[step:] + step))
else:
idx = (lut[:lut[0]:nouter] - lut[0])//2
k = lut[lut[0]+1::2]
k = np.insert(k, idx, 0)
nextk = k[1:]
k = k[:-1]
# offsets
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(nextk, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
delta = fn(*args)
delta = np.maximum(delta, 0)
if lut is not None:
idx = idx[1:] + np.arange(idx.shape[0] - 1)
delta = np.delete(delta, idx)
lut[lut[0]+1::2] = delta
return None, None
return delta, _einsum.lut_mode(delta[step:-step])
@staticmethod
def make_sdd_lut(layout_c, sparse_c, blocks):
nnz = torch.nonzero(layout_c, as_tuple=False)
lut = nnz.reshape(-1).int().cuda()
return lut
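# The sparse-output (SDD) case needs no delta table: the look-up table is just the
# flattened block coordinates of the nonzeros of layout_c, and the launch grid assigns
# one program instance per nonzero output block.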
############################
## Compilation
############################
class instance:
locks = None
kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
sym_a = expr_a.lower()
sym_b = expr_b.lower()
sym_c = expr_c.lower()
sparse_a = [x.lower() for x in expr_a if x.isupper()]
sparse_b = [x.lower() for x in expr_b if x.isupper()]
sparse_c = [x.lower() for x in expr_c if x.isupper()]
layout_a = layouts.get(expr_a)
layout_b = layouts.get(expr_b)
layout_c = layouts.get(expr_c)
# parse axes
axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
axes = axes_b + axes_m + axes_n + axes_k
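# axis classification: batch axes appear in A, B and C; M axes only in A and C;
# N axes only in B and C; K (reduction) axes in A and B but not in C.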
# check block sizes
for d in sparse_a + sparse_b + sparse_c:
if d.upper() not in blocks:
raise ValueError(f'unspecified block size for dimension: {d.upper()}')
# check layout is present
if sparse_a and layout_a is None:
raise ValueError('A is sparse but no layout was provided')
if sparse_b and layout_b is None:
raise ValueError('B is sparse but no layout was provided')
if sparse_c and layout_c is None:
raise ValueError('C is sparse but no layout was provided')
# check dimensions
dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
# TODO: could be cleaner
read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
for d in axes_b + axes_m + axes_n + axes_k:
dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
if d in axes_b and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
if d in axes_k and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
dims = dict()
dims.update(dims_a)
dims.update(dims_b)
for i, d in enumerate(sparse_a):
dims[d] = layout_a.shape[i] * blocks[d.upper()]
for i, d in enumerate(sparse_b):
dims[d] = layout_b.shape[i] * blocks[d.upper()]
# allocate output
shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
if sparse_c:
shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
stride_c = [None] * len(shape_c)
stride_c[-1] = 1
for i in reversed(range(len(shape_c) - 1)):
stride_c[i] = stride_c[i+1] * shape_c[i+1]
# look-up tables
TK = 16 if dtype == torch.float16 else 8
if sparse_a and not sparse_b:
delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
if sparse_b and not sparse_a:
delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
if not sparse_a and not sparse_b:
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
if sparse_c:
delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
# hash for recompilation
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
stride_a_last = stride_a[-1]
stride_b_last = stride_b[-1]
stride_c_last = stride_c[-1]
name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
# recompile if necessary
cache = _einsum.instance.kernel_cache
if name not in cache:
cachesize = len(cache)
cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
dtype,
sym_a, sym_b, sym_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
blocks)
self.kernel = cache[name]
# Initialize locks
if _einsum.instance.locks is None:
_einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
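# The lock buffer is shared by every einsum instance: the generated kernel uses the
# first 1024*1024 integers as spin-locks and the second half as accumulation counters
# when the reduction is split across TZ > 1 program instances.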
# Kernel arguments
dim_m = [dims[d] for d in axes_m]
dim_n = [dims[d] for d in axes_n]
dim_k = [dims[d] for d in axes_k]
dim_b = [dims[d] for d in axes_b]
M = reduce(mul, dim_m, 1)
N = reduce(mul, dim_n, 1)
K = reduce(mul, dim_k, 1)
B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
stride_a = list(stride_a[:-1])
stride_b = list(stride_b[:-1])
stride_c = list(stride_c[:-1])
alpha = 1.
div_m = 1
self.args = [None, None, None]
self.args += [_einsum.instance.locks]
self.args += [alpha, M, N, K, div_m]
self.args += dim_m
self.args += dim_n
self.args += dim_k
self.args += dim_b
self.args += stride_a
self.args += stride_b
self.args += stride_c
# LUT for A
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
self.args += [delta_a[TK], delta_a[0]]
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_a).cuda()]
# LUT for B
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
self.args += [delta_b[TK], delta_b[0]]
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_b).cuda()]
# LUT for C
if sparse_c:
self.args += [delta_c]
if sparse_a or sparse_b:
width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
self.args += [width]
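# width = number of nonzero outer blocks of the sparse operand: the first header
# entry of the DSD look-up table equals nouter * (number of such blocks).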
# Grid
if sparse_a:
self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
elif sparse_b:
self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
elif sparse_c:
width = int(layout_c.sum())
self.grid = lambda opt: [width, B, opt.d('TZ')]
else:
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
self.pos_c = 2
# save information on the operation
self.expr_a = expr_a
self.expr_b = expr_b
self.expr_c = expr_c
self.matmul_B = B
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
# output shape
self.shape_c = shape_c
def run(self, a, b):
c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
self.kernel(*self.args, grid=self.grid)
return c
############################
## Forward
############################
instance_cache = dict()
registry = dict()
@staticmethod
def forward(ctx, expr, a, b, layouts, blocks):
# compile einsum instance
cache = _einsum.instance_cache
key = (expr, a.dtype,
a.stride(), b.stride(),
a.shape , b.shape)
if key not in cache:
cache[key] = _einsum.instance(expr, a.dtype,
a.stride(), b.stride(),
a.shape , b.shape ,
layouts, blocks)
instance = cache[key]
# run the kernel; the output c is freshly allocated by the instance
c = instance.run(a, b)
# save information in context
ctx.expr_a = instance.expr_a
ctx.expr_b = instance.expr_b
ctx.expr_c = instance.expr_c
ctx.matmul_B = instance.matmul_B
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.save_for_backward(a, b)
return c
############################
## Backward
############################
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
expr_a = ctx.expr_a
expr_b = ctx.expr_b
expr_c = ctx.expr_c
# gradient of first argument
da = None
if ctx.needs_input_grad[1]:
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b)
# gradient of second argument
db = None
if ctx.needs_input_grad[2]:
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy)
# one gradient per forward input: (expr, a, b, layouts, blocks)
return None, da, db, None, None
def einsum(expr, a, b, layouts = None, blocks = dict()):
layouts = dict() if layouts is None else layouts
return _einsum.apply(expr, a, b, layouts, blocks)
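# Example usage (sketch; assumes the module is exposed as triton.ops.einsum and that
# the inputs are contiguous CUDA tensors):
#   import torch, triton
#   a = torch.randn(128, 256, dtype=torch.float16, device='cuda')
#   b = torch.randn(256, 512, dtype=torch.float16, device='cuda')
#   c = triton.ops.einsum('mk,kn->mn', a, b)
# Block-sparse operands/outputs are declared with upper-case indices together with a
# `layouts` dictionary keyed by the full operand sub-expression and a `blocks`
# dictionary giving the block size of each upper-case dimension.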