[GENERAL] Various improvements:

* Block-sparse einsum in triton.ops.einsum (usage sketch below)
* Hacky support for fixed-tile-size masked atomic-add
* Various bugfixes in the parser
Author: Philippe Tillet
Date: 2020-10-25 11:55:58 -07:00
parent 444907589d
commit 049ab989b5
16 changed files with 574 additions and 331 deletions
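To make the first bullet concrete, here is a minimal usage sketch of the new block-sparse interface, distilled from the python/examples/test.py added by this commit. Shapes and block sizes are illustrative, and the compressed storage of the sparse operand simply mirrors that file's to_sparse helper; treat this as a sketch rather than a documented API.

import torch
import triton

# 'bhmk,bHKN->bhmn': upper-case letters mark the block-sparse dimensions of B
expr = 'bhmk,bHKN->bhmn'
blocks = {'H': 1, 'K': 32, 'N': 32}              # block size of each sparse dimension
a = torch.rand(2, 2, 256, 256, device='cuda')    # dense operand
# 0/1 layout of B: one entry per (H, K, N) block
layout_b = torch.randint(0, 2, (2, 256 // 32, 256 // 32), device='cuda')
# B is stored compressed: one (1, 32, 32) tile per non-zero layout entry
b = torch.rand(2, int(layout_b.sum()), 1, 32, 32, device='cuda')
# layouts are keyed by the sub-expression they describe
c = triton.ops.einsum(expr, a, b, {'bHKN': layout_b}, blocks)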

View File

@@ -136,7 +136,7 @@ public:
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
value *create_exp(value* arg, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");

View File

@@ -603,13 +603,13 @@ public:
class atomic_add_inst: public builtin_inst {
private:
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_add"; }
_TRITON_DEFINE_CLONE(atomic_add_inst)
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
};
class exp_inst: public builtin_inst {

View File

@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->get_order(0);
auto order = layouts_->get(ptr)->get_order();
size_t ld;
for(size_t i = 0; i < order.size(); i++){
ld = order[i];
if(ld < x->get_type()->get_tile_rank())
break;
}
//size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
}
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){
ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1);
ir::value* msk = add->get_operand(2);
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
for_each(ptr, [&](indices_t idx){
Value *rmw_ptr = ptrs->get_value(idx);
Value *rmw_val = vals->get_value(idx);
Value *rmw_msk = msks->get_value(idx);
BasicBlock *current_bb = builder_->GetInsertBlock();
Function *parent = builder_->GetInsertBlock()->getParent();
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
builder_->SetInsertPoint(mask_then_bb);
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
builder_->CreateBr(mask_done_bb);
builder_->SetInsertPoint(mask_done_bb);
});
}
else{
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vmap_.at(add->get_operand(0));
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
}
}
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
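The tile branch above implements the new masked atomic-add: each lane branches on its mask bit (mask_then / mask_done) and, only when the bit is set, issues an AtomicRMW FAdd on its own pointer. The NumPy sketch below is a reference for these semantics only (hypothetical helper and variable names, scalar adds standing in for the per-lane atomics), not the actual lowering:

import numpy as np

def masked_tile_atomic_add(dst, offsets, values, mask):
    # dst[offsets[i]] += values[i] wherever mask[i] is set; lanes with a
    # cleared mask skip straight to mask_done and write nothing
    for off, val, m in zip(offsets.ravel(), values.ravel(), mask.ravel()):
        if m:
            dst[off] += val

# accumulate a 2x2 tile into a flat buffer, skipping one lane
dst = np.zeros(4, dtype=np.float32)
masked_tile_atomic_add(dst,
                       offsets=np.array([[0, 1], [2, 3]]),
                       values=np.ones((2, 2), dtype=np.float32),
                       mask=np.array([[True, True], [False, True]]))
print(dst)  # [1. 1. 0. 1.]

In the Triton-C front-end below, the f32_atomic_add and atomic_add_64x64 calls are routed to this instruction, and the preheader gains matching fixed-tile-size declarations.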
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
builder_->SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
// std::cout << typeid(*i).name() << std::endl;
visit_value(i);
}
vmap_[block] = builder_->GetInsertBlock();
}

View File

@@ -253,7 +253,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context);
// std::cout << source << std::endl;
// std::cout << source << std::endl;
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;

View File

@@ -307,8 +307,8 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &na
return insert(atomic_exch_inst::create(ptr, val, name));
}
value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
return insert(atomic_add_inst::create(ptr, val, name));
value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
return insert(atomic_add_inst::create(ptr, val, msk, name));
}
value *builder::create_exp(value *arg, const std::string &name){

View File

@@ -736,14 +736,15 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
// atomic add
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, msk);
}
instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
return new atomic_add_inst(ptr, val, name, next);
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
return new atomic_add_inst(ptr, val, msk, name, next);
}
// exp

View File

@@ -523,7 +523,7 @@ void BinaryOp::RelationalOpTypeChecking() {
}
Convert();
}
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}
@@ -538,7 +538,7 @@ void BinaryOp::EqualityOpTypeChecking() {
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
Convert();
}
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}
@@ -558,7 +558,7 @@ void BinaryOp::LogicalOpTypeChecking() {
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
Error(this, "the operand should be arithmetic type or pointer");
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}

View File

@@ -277,12 +277,14 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name == "f32_atomic_add"){
if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_add(ptr, val));
VisitExpr(funcCall->Args()->at(2));
ir::value* msk = ret_;
return set_ret(bld_->create_atomic_add(ptr, val, msk));
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
@@ -338,6 +340,7 @@ void Generator::VisitTempVar(TempVar* tempVar) {
}
// Statement
// TODO: int x = x; crashes
void Generator::VisitDeclaration(Declaration* decl) {
auto obj = decl->obj_;
// initialize to undef

View File

@@ -650,7 +650,7 @@ Expr* Parser::ParseDerefOp(const Token* tok) {
Expr* pred = nullptr;
if(ts_.Try('?')){
ts_.Expect('(');
pred = ParseCastExpr();
pred = ParseExpr();
ts_.Expect(')');
}
Expr* addr = ParseCastExpr();

View File

@@ -239,6 +239,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
throw std::runtime_error("using too much shared memory");
barriers.run(module);
isel.visit(module, *llvm);
// ir::print(module, std::cout);
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
return res;
}
@@ -351,6 +352,8 @@ std::string function::preheader() {
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern float f32_atomic_add(float*, float);
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);

python/examples/test.py (new file, 109 lines)
View File

@@ -0,0 +1,109 @@
import triton
import numpy
import torch
import itertools
torch.manual_seed(0)
numpy.random.seed(0)
def to_sparse(expr, data, layout, shape, block):
# shape of result
sparse = None
shape_ret = []
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
shape_ret.append(int(layout.sum()))
if d.isupper():
shape_ret.append(block[d])
else:
shape_ret.append(shape[i])
# iterator
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
# create result
ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
ret_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def to_dense(expr, data, layout, shape, block):
sparse = None
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
data_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def test_expr(expr, shape, blocks):
# decompose expr
expr_a, expr_bc = expr.split(",")
expr_b, expr_c = expr_bc.split("->")
# check which arguments are sparse
sparse_a = any(x.isupper() for x in expr_a)
sparse_b = any(x.isupper() for x in expr_b)
sparse_c = any(x.isupper() for x in expr_c)
# allocate data
shape_a = [shape[d.lower()] for d in expr_a]
shape_b = [shape[d.lower()] for d in expr_b]
shape_c = [shape[d.lower()] for d in expr_c]
ref_a = torch.rand(*shape_a, device='cuda')
ref_b = torch.rand(*shape_b, device='cuda')
ref_c = torch.zeros(*shape_c, device='cuda')
# layouts
layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
layout_a = torch.randint(0, 2, layout_a, device='cuda')
layout_b = torch.randint(0, 2, layout_b, device='cuda')
layout_c = torch.randint(0, 2, layout_c, device='cuda')
# triton computation
triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
torch.cuda.synchronize()
# reference computation
ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
if sparse_c:
ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
torch.cuda.synchronize()
print((ref_c - triton_c).abs().max())
# shape characteristics
test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})

View File

@@ -56,8 +56,8 @@ class _dot(torch.autograd.Function):
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
int rxm[TM] = ridx * TM + 0 ... TM;
int rxn[TN] = ridy * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -95,7 +95,7 @@ class _dot(torch.autograd.Function):
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [64, 128],
'TN' : [64, 128],
@@ -107,14 +107,12 @@ class _dot(torch.autograd.Function):
# allocate output
M, K = a.shape
K, N = b.shape
c = triton.empty([M,N], dtype=dtype)
c = torch.empty([M,N], dtype=dtype, device=a.device)
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0),
grid=grid, bench=100)
print(2*M*N*K/(time*1e-6)*1e-9)
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
@@ -126,8 +124,10 @@ M, N, K = 2048, 2048, 2048
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()
#a[:] = 1
#b[:] = 1
#zc = torch.matmul(a,b)
zc = torch.matmul(a,b)
zc_ = dot(a,b)
#print(torch.allclose(zc, zc_))
print(torch.allclose(zc, zc_))

View File

@@ -110,7 +110,7 @@ setup(
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops', 'triton/nn'],
packages=['triton', 'triton/_C', 'triton/ops'],
install_requires=['numpy', 'torch', 'sympy'],
package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')],

View File

@@ -1,5 +1,5 @@
from .kernel import *
#import triton.ops
import triton.ops
#import triton.nn

View File

@@ -82,6 +82,7 @@ class kernel:
grid = kwargs['grid']
libtriton.register_grid((self.op_id, device), grid)
# launch
#print(self.tys)
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
torch.cuda.synchronize()
torch.ops.triton.launch_kernel(self.op_id, device, params)

View File

@@ -5,8 +5,8 @@ from operator import mul
from collections import OrderedDict
from collections import namedtuple
import re
import string
import triton
# torch
import torch
# numpy -- ideally removed in a future release
import numpy as np
@@ -22,33 +22,14 @@ class _einsum(torch.autograd.Function):
#############################
## Triton-C code generation
#############################
def print_cc(expr, axes_0, axes_1, axes_2):
class TritonCodePrinter(C89CodePrinter):
def __init__(self, axes_0, axes_1, axes_2):
super(TritonCodePrinter, self).__init__()
self.axes_0 = axes_0
self.axes_1 = axes_1
self.axes_2 = axes_2
def _print_Symbol(self, expr):
name = super(C89CodePrinter, self)._print_Symbol(expr)
if expr in self.axes_0:
return f'r{name}[:, newaxis, newaxis]'
if expr in self.axes_1:
return f'r{name}[newaxis, :, newaxis]'
if expr in self.axes_2:
return f'r{name}[newaxis, newaxis, :]'
return name
def _print_Indexed(self, expr):
assert len(expr.indices) == 1
return "*(%s + %s)" % (self._print(expr.base.label),
self._print(expr.indices[0]))
return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)
def print_cc(expr, axes_0, axes_1, axes_2, prefix):
if expr in axes_0:
return f'{prefix}r{expr}[:, newaxis, newaxis]'
if expr in axes_1:
return f'{prefix}r{expr}[newaxis, :, newaxis]'
if expr in axes_2:
return f'{prefix}r{expr}[newaxis, newaxis, :]'
return expr
def unpack_cc(tile, axes, prefix, remat):
ret = ''
@@ -59,7 +40,7 @@ class _einsum(torch.autograd.Function):
currs = ''.join(axes[: len(axes) - i])
nexts = ''.join(axes[: len(axes) - (i + 1)])
ty = '' if remat else 'int '
sz = '' if remat else f'[{tile}]'
sz = '' if remat or tile is None else f'[{tile}]'
ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
return ret
@@ -69,18 +50,26 @@ class _einsum(torch.autograd.Function):
ret = dict(zip(expr, ret))
return ret
def make_kernel(name, dtype, mask,
def make_kernel(name, dtype,
expr_a, expr_b, expr_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted, varnames):
blocks):
use_lut_a = True
use_lut_b = True
outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]
src = ""
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
@@ -105,41 +94,91 @@ __global__ void {name}(
for d in dim:
src += f", int dim_{d}"
src += "\n "
for dim, name, mult in zip([expr_a, expr_b, expr_c],
for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
['a', 'b', 'c'],
[multipleof_a, multipleof_b, multipleof_c]):
[multipleof_a, multipleof_b, multipleof_c],
[sparse_a, sparse_b, sparse_c]):
for d in range(len(dim) - 1):
attr = f'__multipleof({mult})'
src += f", int stride_{name}_{d} {attr}"
if sparse and dim[d] == sparse[0]:
src += f', int stride_{name}_block __multipleof({mult})'
src += f", int stride_{name}_{d} __multipleof({mult})"
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
src += f", int rem_delta_a __multipleof({multipleof_a})"
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
src += f", int rem_delta_b __multipleof({multipleof_b})"
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
src += "\n"
for ptr in subscripted:
src += f", int* {ptr}"
for name in varnames:
src += f", int {name}"
src += "\n "
if sparse_c:
src += ", int* CD"
if sparse_a or sparse_b:
src += ", int width"
src += """) {
// program identifiers
int pid_0 = get_program_id(0);
int pid_1 = get_program_id(1);
"""
if sparse_a:
src += f"""
int off_n = pid_0 / width;
int off_header = pid_0 % width;
int* header = AD + off_header * {2 + len(outer_sparse_a)};
int* pdelta = AD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_a):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int inca = *(pdelta + 0);
int incb = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_a))} = pid_1;
"""
_einsum.unpack_cc(None, outer_dense_a, "off_", False)
elif sparse_b:
src += f"""
int off_m = pid_0 / width;
int off_header = pid_0 % width;
int* header = BD + off_header * {2 + len(outer_sparse_b)};
int* pdelta = BD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_b):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int incb = *(pdelta + 0);
int inca = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_b))} = pid_1;
"""
_einsum.unpack_cc(None, outer_dense_b, "off_", False)
elif sparse_c:
src += f"""
// load LUT header
int *header = CD + pid_0 * {len(sparse_c)};"""
for i, d in enumerate(sparse_c):
src += f"""
int off_{d} = *(header + {i});"""
src += f"""
int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
else:
src += """
// re-order outer program ids
int grid_m = (matmul_m + TM - 1) / TM;
int grid_n = (matmul_n + TN - 1) / TN;
int pid_mn = get_program_id(0) / div_m;
int pid_n = pid_mn % grid_n;
int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
// get batch program id
int pid_b = get_program_id(1);
int off_mn = pid_0 / div_m;
int off_n = off_mn % grid_n;
int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
int off_b = get_program_id(1);"""
src += """
#if TZ == 1
int off_k = 0;
#else
@@ -155,55 +194,71 @@ __global__ void {name}(
// create ranges
"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
currs = ''.join(map(str,axes))
if axes:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', False)
src += """
sparse = sparse_a + sparse_b + sparse_c
for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
[['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
if not axes:
continue
currs = ''.join(map(str,axes))
has_sparse_component = set(axes) & set(sparse)
if has_sparse_component:
src += f" int r{currs}[{tile}] = 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
else:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
for pfx in prefixes:
for d in axes:
is_dense_dim = d not in sparse
is_dense_storage = (pfx == 'a' and not sparse_a) or\
(pfx == 'b' and not sparse_b) or\
(pfx == 'c' and not sparse_c)
if not is_dense_dim and is_dense_storage:
src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
elif is_dense_dim and has_sparse_component:
src += f" int {pfx}r{d}[{tile}] = off_{d};\n"
else:
src += f" int {pfx}r{d}[{tile}] = r{d};\n"
src += f"""
// initialize pointers to A
int offa[TM, TK, TB] = """
int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pa[TM, TK, TB] = A + offa;"""
if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to A look-up table
int offadelta[TK] = off_k + 0 ... TK;
int {spec} *padelta[TK] = {cast}AD + offadelta;
int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
src += """
src += f"""
// initialize pointers to B
int offb[TK, TN, TB] = """
int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pb[TK, TN, TB] = B + offb;"""
if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
@@ -211,66 +266,99 @@ __global__ void {name}(
int offbdelta[TK] = off_k + 0 ... TK;
int *pbdelta[TK] = BD + offbdelta;"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
src += f"""
// prefetch
int prefetch_k = select(rem_k > 0, rem_k, TK);
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = {rk} < prefetch_k;
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
if sparse_a:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incda = *(pdelta + 0);
int incdb = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if sparse_b:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incdb = *(pdelta + 0);
int incda = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += rem_delta_a;"""
else:
elif not sparse_a and not sparse_b:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += rem_delta_b;"""
else:
elif not sparse_a and not sparse_b:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;
#ifdef MASK
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
acc = bitcast<float[TM, TN, TB]>(bits & MASK);
#endif
// load inputs
checkk = k > TK;
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;"""
b = *?(checkb)pb;
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
// update pointers"""
if sparse_a:
src += """
pa += stride_a_inner;"""
pdelta += 2;
incda = *(pdelta + 0);
incdb = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
elif sparse_b:
src += """
pdelta += 2;
incdb = *(pdelta + 0);
incda = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
else:
src += """
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += stride_b_inner;"""
else:
src += """
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
@@ -278,38 +366,22 @@ __global__ void {name}(
}}
TYPE c[TM, TN, TB] = acc;
// re-materialize ranges
pid_mn = get_program_id(0) / div_m;
pid_n = pid_mn % grid_n;
pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
"""
for axes, tile, off in zip([axes_m, axes_n, axes_b],
['TM', 'TN', 'TB'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
currs = ''.join(map(str,axes))
if axes:
src += f" r{currs} = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', True)
src += """
// initialize pointers to C
int offc[TM, TN, TB] = """
int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
src += f"\n + ({ccode}) * {stride}"
src += ';'
src += """
TYPE *pc[TM, TN, TB] = C + offc;
// bounds-checking
checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
checkcn[newaxis, :, newaxis];
// write back
#if TZ == 1
@@ -330,11 +402,11 @@ __global__ void {name}(
}
"""
# compilation options
TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
TM, TN, TB, TZ = [32], [32], 1, [1]
TK = 16 if dtype==torch.float16 else 8
defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask is not None:
defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
for d, B in blocks.items():
defines[f'BLOCK{d}'] = B
# create kernel
ret = triton.kernel(src, defines=defines)
# set constant
@@ -376,27 +448,81 @@ __global__ void {name}(
k = k // dims[d]
return ret
def make_delta(axes, step, stride, dims, symbols, arrays):
def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
# depth of reductions
depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
# outer dimension indices
outer = torch.nonzero(depth, as_tuple=False)
outer = [outer[:,i] for i in range(outer.shape[1])]
# find offset of outer dimensions
depth = depth.view(-1)
offsets = torch.zeros_like(depth)
offsets[1:] = torch.cumsum(depth[:-1], 0)
# compute delta for b
# TODO: support multiple sparse red indices
col = next((i for i, d in enumerate(sparse) if d in axes), None)
block = blocks[sparse[-1].upper()]
div = block // step
delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
delta_b *= block
delta_b = [delta_b + step*i for i in range(div)]
delta_b = torch.stack(delta_b, dim=1)
delta_b = delta_b.view(-1)
# compute delta for a
bstride = 1
for d in sparse[::-1]:
if d in axes:
break
bstride *= blocks[d.upper()]
order = [d for d in sparse if d not in axes] +\
[d for d in sparse if d in axes]
idx = [sparse.index(d) for d in order]
layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
layout = layout.permute(*idx)
delta_a = layout[layout > 0] - 1
delta_a *= np.prod(list(blocks.values()))
saved = delta_a[offsets]
delta_a[1:] = delta_a[1:] - delta_a[:-1]
delta_a = delta_a.view(-1, 1).repeat(1, div)
delta_a[:, 1:] = step*bstride
delta_a[:, 0] -= (div - 1)*step*bstride
delta_a[offsets, 0] = saved
delta_a = delta_a.view(-1)
delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
# form look-up table
depth *= blocks[symbols[-1].upper()]
offsets *= div
header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
nouter = 2 + len(outer)
header[::nouter] = header[::nouter]*2 + header.shape[0]
lut = torch.cat((header, delta)).int().cpu().numpy()
return lut, nouter, _einsum.LUT_MODE.DRAM
def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
# symbolic pointer increments
symbols = [sp.symbols(x) for x in symbols]
delta = _einsum.symbolic_delta(symbols, axes)
args = [f'stride{d}' for d in range(len(stride))]
args += [f'{sk}' for sk in axes]
args += [f'next{sk}' for sk in axes]
args += [f'{sk}' for sk, _ in arrays]
fn = sp.lambdify(args, delta, 'numpy')
# inner axes values
inner = [dims[d] for d in axes]
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step,
# rem, rem + 1, ... rem + inner]
k = np.concatenate((np.arange(step),
np.arange(rem, inner))).astype(np.int32)
# nextk = [rem, 1 + rem, ..., step + rem,
# rem + step, rem + 1 + step, ..., inner + step]
nextk = np.concatenate((k[:step] + rem,
k[step:] + step))
if lut is None:
# inner axes values
inner = [dims[d] for d in axes]
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
# nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
nextk = np.concatenate((k[:step] + rem, k[step:] + step))
else:
idx = (lut[:lut[0]:nouter] - lut[0])//2
k = lut[lut[0]+1::2]
k = np.insert(k, idx, 0)
nextk = k[1:]
k = k[:-1]
# offsets
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(nextk, axes, dims)
@@ -404,71 +530,20 @@ __global__ void {name}(
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
args += [x for _, x in arrays]
delta = fn(*args)
delta = np.maximum(delta, 0)
if lut is not None:
idx = idx[1:] + np.arange(idx.shape[0] - 1)
delta = np.delete(delta, idx)
lut[lut[0]+1::2] = delta
return None, None
return delta, _einsum.lut_mode(delta[step:-step])
############################
## Einsum parsing
############################
def uniq(seq):
seen = set()
seen_add = seen.add
return [x for x in seq if not (x in seen or seen_add(x))]
def parse_axes(expr_a, expr_b, expr_c, subscripted):
is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted
sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
sym_c = [x for s in expr_c for x in s.free_symbols]
batch = [d for d in sym_a if d in sym_b and d in sym_c]
outer = [d for d in sym_a if d not in sym_b and d in sym_c]
inner = [d for d in sym_a if d in sym_b and d not in sym_c]
variables = [d for d in sym_a if d not in sym_b and d not in sym_c]
return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner), variables
def replace_subscript(expr, arrays):
# replace array indexing by Indexed()
indexed = re.findall('([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
for x in indexed:
arrays.append(x[0])
expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
return expr
def parse_expr(expr, arrays):
# extract symbols
sym = []
i = 0
while i < len(expr):
d = expr[i]
if d == '(':
size = expr[i:].find(')')
d = expr[i : i + size + 1]
d = _einsum.replace_subscript(d, arrays)
sym.append(parse_expr(d))
i += size + 1
else:
sym.append(parse_expr(d))
i += 1
return sym
############################
## Preprocessing
############################
@staticmethod
def pad(tensor, pad):
pad = pad + [0] * (2*len(tensor.shape) - len(pad))
begin = [ x if x > 0 else None for x in pad[-1::-2]]
end = [-x if x > 0 else None for x in pad[-2::-2]]
slices = [slice(b, e) for b, e in zip(begin, end)]
tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0)
tensor = tensor[slices]
return tensor
def make_sdd_lut(layout_c, sparse_c, blocks):
nnz = torch.nonzero(layout_c, as_tuple=False)
lut = nnz.reshape(-1).int().cuda()
return lut
############################
## Compilation
@@ -479,68 +554,78 @@ __global__ void {name}(
locks = None
kernel_cache = dict()
@staticmethod
def _tile(M, N, B, TMs, TNs, TBs, TZs, TK):
smp = 15
# occupancy estimation
grid = lambda TM, TN, TB, TZ: \
triton.cdiv(M, TM)* \
triton.cdiv(N, TN)* \
triton.cdiv(B, TB)* \
TZ
occupancy = lambda TM, TN, TB, TZ: \
min(grid(TM, TN, TB, TZ), 4*smp)
# arithmetic intensity estimation
intensity = lambda TM, TN: \
TM * TN * TK / (TM*TK + TK*TN)
# occupancy/intensity for all configurations
estimates = {(TM, TN, TB, TZ): (occupancy(TM, TN, TB, TZ), intensity(TM, TN)) \
for TM in TMs \
for TN in TNs \
for TB in TBs \
for TZ in TZs }
# returns configuration that maximizes occupancy subject to maximizing intensity
estimates = sorted(estimates.items(),
key=lambda item: item[1],
reverse=True)
return estimates[0][0]
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c, varnames):
def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
subscripted = []
sym_a = _einsum.parse_expr(expr_a, subscripted)
sym_b = _einsum.parse_expr(expr_b, subscripted)
sym_c = _einsum.parse_expr(expr_c, subscripted)
sym_a = expr_a.lower()
sym_b = expr_b.lower()
sym_c = expr_c.lower()
sparse_a = [x.lower() for x in expr_a if x.isupper()]
sparse_b = [x.lower() for x in expr_b if x.isupper()]
sparse_c = [x.lower() for x in expr_c if x.isupper()]
layout_a = layouts.get(expr_a)
layout_b = layouts.get(expr_b)
layout_c = layouts.get(expr_c)
# parse axes
axes_b, axes_m, axes_k, var = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
_, axes_n, _, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
axes = axes_b + axes_m + axes_n + axes_k
# unresolved symbols
unresolved = [x for x in map(str, var) if x not in varnames]
if unresolved:
raise ValueError(f'unresolved symbols: {unresolved}')
# check block sizes
for d in sparse_a + sparse_b + sparse_c:
if d.upper() not in blocks:
raise ValueError(f'unspecified block size for dimension: {d.upper()}')
# check layout is present
if sparse_a and layout_a is None:
raise ValueError('A is sparse but no layout provided')
if sparse_b and layout_b is None:
raise ValueError('B is sparse but no layout provided')
if sparse_c and layout_c is None:
raise ValueError('C is sparse but no layout provided')
# check dimensions
dims_a = dict(zip(sym_a, shape_a))
dims_b = dict(zip(sym_b, shape_b))
dims_c = dict(zip(sym_c, shape_c))
for axes in [axes_b, axes_k]:
for d in axes:
dim_a = dims_a[d] if d in sym_a else None
dim_b = dims_b[d] if d in sym_b else None
if dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible dimension {d}'
f' (a: {dim_a}; b: {dim_b})')
dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
# TODO: could be cleaner
read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
for d in axes_b + axes_m + axes_n + axes_k:
dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
if d in axes_b and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
if d in axes_k and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
dims = dict()
dims.update(dims_a)
dims.update(dims_b)
dims.update(dims_c)
for i, d in enumerate(sparse_a):
dims[d] = layout_a.shape[i] * blocks[d.upper()]
for i, d in enumerate(sparse_b):
dims[d] = layout_b.shape[i] * blocks[d.upper()]
# allocate output
shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
if sparse_c:
shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
stride_c = [None] * len(shape_c)
stride_c[-1] = 1
for i in reversed(range(len(shape_c) - 1)):
stride_c[i] = stride_c[i+1] * shape_c[i+1]
# look-up tables
TK = 16 if dtype == torch.float16 else 8
arrays = [(x, arrays[x]) for x in subscripted]
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
if sparse_a and not sparse_b:
delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
if sparse_b and not sparse_a:
delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
if not sparse_a and not sparse_b:
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
if sparse_c:
delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
# hash for recompilation
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
@@ -548,7 +633,7 @@ __global__ void {name}(
stride_a_last = stride_a[-1]
stride_b_last = stride_b[-1]
stride_c_last = stride_c[-1]
name = f'{dtype}_{mask}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
# recompile if necessary
@@ -556,14 +641,15 @@ __global__ void {name}(
if name not in cache:
cachesize = len(cache)
cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
dtype, mask,
dtype,
sym_a, sym_b, sym_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted, varnames)
blocks)
self.kernel = cache[name]
# Initialize locks
if _einsum.instance.locks is None:
@@ -576,42 +662,55 @@ __global__ void {name}(
M = reduce(mul, dim_m, 1)
N = reduce(mul, dim_n, 1)
K = reduce(mul, dim_k, 1)
B = reduce(mul, dim_b, 1)
B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
stride_a = list(stride_a[:-1])
stride_b = list(stride_b[:-1])
stride_c = list(stride_c[:-1])
arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
alpha = 1.
div_m = 1
self.args = [None, None, None,
_einsum.instance.locks,
alpha, M, N, K, div_m] +\
dim_m + dim_n + dim_k + dim_b +\
stride_a + stride_b + stride_c
self.args = [None, None, None]
self.args += [_einsum.instance.locks]
self.args += [alpha, M, N, K, div_m]
self.args += dim_m
self.args += dim_n
self.args += dim_k
self.args += dim_b
self.args += stride_a
self.args += stride_b
self.args += stride_c
# LUT for A
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
self.args += [delta_a[TK], delta_a[0]]
if lut_mode_a == _einsum.LUT_MODE.DRAM:
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_a).cuda()]
# LUT for B
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
self.args += [delta_b[TK], delta_b[0]]
if lut_mode_b == _einsum.LUT_MODE.DRAM:
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_b).cuda()]
# Einsum dependents
self.args += arrays
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]
# LUT for C
if sparse_c:
self.args += [delta_c]
if sparse_a or sparse_b:
width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
self.args += [width]
# Grid
if sparse_a:
self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
elif sparse_b:
self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
elif sparse_c:
width = int(layout_c.sum())
self.grid = lambda opt: [width, B, opt.d('TZ')]
else:
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
self.pos_c = 2
# user-provided variables
self.pos_vars = len(self.args)
self.varnames = varnames
self.args += [None] * len(varnames)
# save information on the operation
self.expr_a = expr_a
self.expr_b = expr_b
@@ -620,16 +719,17 @@ __global__ void {name}(
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
# output shape
self.shape_c = shape_c
def run(self, a, b, c, values, bench):
def run(self, a, b):
c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
for i, name in enumerate(self.varnames):
self.args[self.pos_vars + i] = values[name]
return self.kernel(*self.args, grid=self.grid, bench=bench)
self.kernel(*self.args, grid=self.grid)
return c
@@ -639,25 +739,23 @@ __global__ void {name}(
############################
instance_cache = dict()
registry = triton.utils.id_dict()
registry = dict()
@staticmethod
def forward(ctx, expr, a, b, output, mask, arrays, bench, values):
def forward(ctx, expr, a, b, layouts, blocks):
# compile einsum instance
cache = _einsum.instance_cache
key = (expr, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, output.shape, mask)
a.stride(), b.stride(),
a.shape , b.shape)
if key not in cache:
cache[key] = _einsum.instance(expr, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, arrays,
mask, output.shape, values.keys())
a.stride(), b.stride(),
a.shape , b.shape ,
layouts, blocks)
instance = cache[key]
# run and mark as dirty output modified in-place
perf = instance.run(a, b, output, values, bench)
ctx.mark_dirty(output)
# run kernel; the output c is freshly allocated rather than modified in-place
c = instance.run(a, b)
# save information in context
ctx.is_extended = instance.is_extended
ctx.expr_a = instance.expr_a
ctx.expr_b = instance.expr_b
ctx.expr_c = instance.expr_c
@@ -665,10 +763,8 @@ __global__ void {name}(
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.forward_ms = perf
ctx.save_for_backward(a, b)
_einsum.registry[output] = ctx
return output
return c
############################
@@ -677,9 +773,6 @@ __global__ void {name}(
@staticmethod
def backward(ctx, dy):
if ctx.is_extended:
raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
' please write your own autograd function')
a, b = ctx.saved_tensors
expr_a = ctx.expr_a
expr_b = ctx.expr_b
@@ -694,10 +787,8 @@ __global__ void {name}(
if ctx.needs_input_grad[2]:
db = torch.empty_like(b)
einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
return None, da, db, None, None, None, None, None
return None, da, db, None, None, None, None, None, None, None
def einsum(expr, a, b, output,
mask=None, arrays=dict(),
bench=False, values=dict()):
return _einsum.apply(expr, a, b, output, mask, arrays, bench, values)
def einsum(expr, a, b, layouts = None, blocks = dict()):
return _einsum.apply(expr, a, b, layouts, blocks)