From 049ab989b515fb324a6b36c6cf5cb6efd11c7d11 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 25 Oct 2020 11:55:58 -0700
Subject: [PATCH] [GENERAL] Various improvements:

* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in parser
---
 include/triton/ir/builder.h          |   2 +-
 include/triton/ir/instructions.h     |   4 +-
 lib/codegen/selection/generator.cc   |  39 +-
 lib/driver/module.cc                 |   2 +-
 lib/ir/builder.cc                    |   4 +-
 lib/ir/instructions.cc               |   9 +-
 lib/lang/ast.cc                      |   6 +-
 lib/lang/code_gen.cc                 |   7 +-
 lib/lang/parser.cc                   |   2 +-
 lib/runtime/function.cc              |   3 +
 python/examples/test.py              | 109 +++++
 python/examples/tutorials/mat_mul.py |  18 +-
 python/setup.py                      |   2 +-
 python/triton/__init__.py            |   2 +-
 python/triton/kernel.py              |   1 +
 python/triton/ops/einsum.py          | 695 +++++++++++++++------------
 16 files changed, 574 insertions(+), 331 deletions(-)
 create mode 100644 python/examples/test.py

diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 1c87997f5..f204ee0b5 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -136,7 +136,7 @@ public:
   value *create_get_num_program(unsigned axis, const std::string &name = "");
   value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
   value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
-  value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
+  value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
   value *create_exp(value* arg, const std::string &name = "");
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector &perm = {}, const std::string &name = "");
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 54dd2e736..83255d215 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -603,13 +603,13 @@ public:
 
 class atomic_add_inst: public builtin_inst {
 private:
-  atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
+  atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
   std::string repr_impl() const { return "atomic_add"; }
   _TRITON_DEFINE_CLONE(atomic_add_inst)
   _TRITON_DEFINE_ACCEPT(atomic_add_inst)
 
 public:
-  static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
 };
 
 class exp_inst: public builtin_inst {
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 3dd61cfc3..5813c1ff1 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
 void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
-  size_t ld = layouts_->get(ptr)->get_order(0);
+  auto order = layouts_->get(ptr)->get_order();
+  size_t ld;
+  for(size_t i = 0; i < order.size(); i++){
+    ld = order[i];
+    if(ld < x->get_type()->get_tile_rank())
+      break;
+  }
+  //size_t ld = layouts_->get(ptr)->get_order(0);
   unsigned alignment = alignment_->get(ptr, ld);
   distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
   distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 }
 
 void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+  if(add->get_type()->is_tile_ty()){
+    ir::value* ptr = add->get_operand(0);
+    ir::value* val = add->get_operand(1);
+    ir::value* msk = add->get_operand(2);
+    distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
+    distributed_tile* vals = (distributed_tile*)tmap_.at(val);
+    distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
+    for_each(ptr, [&](indices_t idx){
+      Value *rmw_ptr = ptrs->get_value(idx);
+      Value *rmw_val = vals->get_value(idx);
+      Value *rmw_msk = msks->get_value(idx);
+      BasicBlock *current_bb = builder_->GetInsertBlock();
+      Function *parent = builder_->GetInsertBlock()->getParent();
+      BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
+      BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
+      builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
+      builder_->SetInsertPoint(mask_then_bb);
+      builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
+                                AtomicOrdering::Monotonic,
+                                SyncScope::System);
+      builder_->CreateBr(mask_done_bb);
+      builder_->SetInsertPoint(mask_done_bb);
+    });
+  }
+  else{
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
   Value *rmw_ptr = vmap_.at(add->get_operand(0));
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
   builder_->CreateBr(tid_0_done_bb);
   builder_->SetInsertPoint(tid_0_done_bb);
   tgt_->add_memfence(module, *builder_);
+  }
 }
 
 void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
 void generator::visit_basic_block(ir::basic_block * block) {
   BasicBlock *parent = (BasicBlock*)vmap_[block];
   builder_->SetInsertPoint(parent);
-  for(ir::instruction *i: block->get_inst_list())
+  for(ir::instruction *i: block->get_inst_list()){
+    // std::cout << typeid(*i).name() << std::endl;
     visit_value(i);
+  }
   vmap_[block] = builder_->GetInsertBlock();
 }
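The tile branch added to visit_atomic_add_inst above lowers each lane to a conditional branch around an AtomicRMW(FAdd): lanes whose mask bit is clear skip the read-modify-write entirely. A minimal Python model of that per-lane semantics, for intuition only (the function and names below are illustrative, not part of the patch):

import torch

def masked_atomic_add_ref(dst, idx, val, msk):
    # Serial model of the mask_then/mask_done control flow emitted above:
    # each lane adds val[i] into dst[idx[i]] iff its mask bit is set.
    for i in range(val.numel()):
        if bool(msk[i]):
            dst[idx[i]] += val[i]

dst = torch.zeros(4)
masked_atomic_add_ref(dst, torch.tensor([0, 0, 1, 2]),
                      torch.ones(4), torch.tensor([True, True, False, True]))
# dst is now [2., 0., 1., 0.]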
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index 78f42d9a5..20586f57f 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -253,7 +253,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr ll
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
   cu_context::context_switcher ctx(*context);
-// std::cout << source << std::endl;
+  // std::cout << source << std::endl;
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
   unsigned int errbufsize = 8096;
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index 50404df46..c100f461a 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -307,8 +307,8 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){
   return insert(atomic_exch_inst::create(ptr, val, name));
 }
 
-value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
-  return insert(atomic_add_inst::create(ptr, val, name));
+value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
+  return insert(atomic_add_inst::create(ptr, val, msk, name));
 }
 
 value *builder::create_exp(value *arg, const std::string &name){
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 2c14a1e83..6ede70001 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -736,14 +736,15 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
 
 // atomic add
 
-atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
-  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
+atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
+  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
   set_operand(0, ptr);
   set_operand(1, val);
+  set_operand(2, msk);
 }
 
-instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
-  return new atomic_add_inst(ptr, val, name, next);
+instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
+  return new atomic_add_inst(ptr, val, msk, name, next);
 }
 
 // exp
diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc
index 18bfbec71..c0887574b 100644
--- a/lib/lang/ast.cc
+++ b/lib/lang/ast.cc
@@ -523,7 +523,7 @@ void BinaryOp::RelationalOpTypeChecking() {
     }
     Convert();
   }
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }
 
@@ -538,7 +538,7 @@ void BinaryOp::EqualityOpTypeChecking() {
       Error(this, "invalid operands to binary %s", tok_->str_.c_str());
     Convert();
   }
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }
 
@@ -558,7 +558,7 @@ void BinaryOp::LogicalOpTypeChecking() {
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
     Error(this, "the operand should be arithmetic type or pointer");
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }
 
diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc
index 2cf20d85e..323808ac7 100644
--- a/lib/lang/code_gen.cc
+++ b/lib/lang/code_gen.cc
@@ -277,12 +277,14 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
     ir::value* val = ret_;
     return set_ret(bld_->create_atomic_exch(ptr, val));
   }
-  if(name == "f32_atomic_add"){
+  if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
     VisitExpr(funcCall->Args()->at(0));
     ir::value* ptr = ret_;
     VisitExpr(funcCall->Args()->at(1));
     ir::value* val = ret_;
-    return set_ret(bld_->create_atomic_add(ptr, val));
+    VisitExpr(funcCall->Args()->at(2));
+    ir::value* msk = ret_;
+    return set_ret(bld_->create_atomic_add(ptr, val, msk));
   }
   if(name == "sqrtf"){
     VisitExpr(funcCall->Args()->at(0));
@@ -338,6 +340,7 @@ void Generator::VisitTempVar(TempVar* tempVar) {
 }
 
 // Statement
+// TODO: int x = x; crashes
 void Generator::VisitDeclaration(Declaration* decl) {
   auto obj = decl->obj_;
   // initialize to undef
diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc
index c3f9c5ab7..8025ca563 100644
--- a/lib/lang/parser.cc
+++ b/lib/lang/parser.cc
@@ -650,7 +650,7 @@ Expr* Parser::ParseDerefOp(const Token* tok) {
   Expr* pred = nullptr;
   if(ts_.Try('?')){
     ts_.Expect('(');
-    pred = ParseCastExpr();
+    pred = ParseExpr();
     ts_.Expect(')');
   }
   Expr* addr = ParseCastExpr();
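Two front-end consequences of the hunks above: comparison, equality and logical operators now type as bool tiles (T_BOOL) instead of int, and the predicate of a masked dereference is parsed with ParseExpr() rather than ParseCastExpr(), so compound predicates are accepted. A hypothetical kernel string (not taken from the patch) that relies on both:

src = '''
__global__ void copy(float *X, float *Y, int N) {
    int rx[128] = get_program_id(0) * 128 + 0 ... 128;
    float *px[128] = X + rx;
    float *py[128] = Y + rx;
    bool check[128] = rx < N;                 // comparison is now bool[128]
    float x[128] = *?(check && (rx >= 0))px;  // full expression as predicate
    *?(check)py = x;
}
'''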
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 4d8edc523..ef735b670 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -239,6 +239,7 @@ std::unique_ptr function::make_bin(ir::module &module,
     throw std::runtime_error("using too much shared memory");
   barriers.run(module);
   isel.visit(module, *llvm);
+  // ir::print(module, std::cout);
   std::unique_ptr res(driver::module::create(context, std::move(llvm)));
   return res;
 }
@@ -351,6 +352,8 @@ std::string function::preheader() {
 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
 extern float f32_atomic_add(float*, float);
+extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
+extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
 extern int get_program_id(int);
 extern int get_num_programs(int);
 extern float sqrtf(float);
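The preheader now declares fixed-tile predicated atomic-add intrinsics; note that only atomic_add_64x64 is wired up in Generator::VisitFuncCall above, which is the "hacky" part the commit message owns up to. A sketch of how a kernel would use it for a bounds-checked accumulation, following the epilogue idiom from mat_mul.py below (hypothetical kernel, illustration only):

src = '''
__global__ void acc(float *C, int ldc, int M, int N) {
    int rm[64] = get_program_id(0) * 64 + 0 ... 64;
    int rn[64] = get_program_id(1) * 64 + 0 ... 64;
    float *pc[64, 64] = C + rm[:, newaxis] * ldc + rn[newaxis, :];
    float c[64, 64] = 1;
    bool check[64, 64] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N);
    atomic_add_64x64(pc, c, check);
}
'''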
diff --git a/python/examples/test.py b/python/examples/test.py
new file mode 100644
index 000000000..c2ff7d473
--- /dev/null
+++ b/python/examples/test.py
@@ -0,0 +1,109 @@
+import triton
+import numpy
+import torch
+import itertools
+
+torch.manual_seed(0)
+numpy.random.seed(0)
+
+def to_sparse(expr, data, layout, shape, block):
+    # shape of result
+    sparse = None
+    shape_ret = []
+    for i, d in enumerate(expr):
+        if d.isupper() and sparse is None:
+            sparse = i
+            shape_ret.append(int(layout.sum()))
+        if d.isupper():
+            shape_ret.append(block[d])
+        else:
+            shape_ret.append(shape[i])
+    # iterator
+    steps = [block[d] if d.isupper() else 1 for d in expr]
+    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
+    # create result
+    ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
+    blockid = 0
+    nzblockid = 0
+    for curr in itertools.product(*it):
+        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
+            blockid = 0
+            nzblockid = 0
+        data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
+        ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
+        ret_slice.insert(sparse, nzblockid)
+        if int(layout.view(-1)[blockid]):
+            ret[ret_slice] = data[data_slice]
+            nzblockid += 1
+        blockid += 1
+    return ret
+
+def to_dense(expr, data, layout, shape, block):
+    sparse = None
+    for i, d in enumerate(expr):
+        if d.isupper() and sparse is None:
+            sparse = i
+
+    ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
+    steps = [block[d] if d.isupper() else 1 for d in expr]
+    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
+    blockid = 0
+    nzblockid = 0
+    for curr in itertools.product(*it):
+        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
+            blockid = 0
+            nzblockid = 0
+        ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
+        data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
+        data_slice.insert(sparse, nzblockid)
+        if int(layout.view(-1)[blockid]):
+            ret[ret_slice] = data[data_slice]
+            nzblockid += 1
+        blockid += 1
+    return ret
+
+def test_expr(expr, shape, blocks):
+    # decompose expr
+    expr_a, expr_bc = expr.split(",")
+    expr_b, expr_c = expr_bc.split("->")
+    # check which argument is sparse
+    sparse_a = any(x.isupper() for x in expr_a)
+    sparse_b = any(x.isupper() for x in expr_b)
+    sparse_c = any(x.isupper() for x in expr_c)
+    # allocate data
+    shape_a = [shape[d.lower()] for d in expr_a]
+    shape_b = [shape[d.lower()] for d in expr_b]
+    shape_c = [shape[d.lower()] for d in expr_c]
+    ref_a = torch.rand(*shape_a, device='cuda')
+    ref_b = torch.rand(*shape_b, device='cuda')
+    ref_c = torch.zeros(*shape_c, device='cuda')
+    # layouts
+    layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
+    layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
+    layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
+    layout_a = torch.randint(0, 2, layout_a, device='cuda')
+    layout_b = torch.randint(0, 2, layout_b, device='cuda')
+    layout_c = torch.randint(0, 2, layout_c, device='cuda')
+    # triton computation
+    triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
+    triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
+    layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
+    triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
+    torch.cuda.synchronize()
+    # reference computation
+    ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
+    ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
+    ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
+    if sparse_c:
+        ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
+    torch.cuda.synchronize()
+    print((ref_c - triton_c).abs().max())
+
+
+
+
+# shape characteristics
+test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
+test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
+test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})
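A note on the notation exercised by these tests: uppercase subscripts mark block-sparse dimensions, and a sparse operand's layout tensor carries one {0, 1} flag per block along exactly those dimensions. The shape arithmetic is the one the test uses for layout_a; spelled out for the first call:

shape  = {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}
blocks = {'H': 1, 'M': 32, 'K': 32}
layout_shape = [shape[d.lower()] // blocks[d] for d in 'bHMK' if d.isupper()]
assert layout_shape == [2, 8, 8]   # one flag per (H, M, K) block of A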
diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py
index 78ecb9712..8ec788e45 100644
--- a/python/examples/tutorials/mat_mul.py
+++ b/python/examples/tutorials/mat_mul.py
@@ -56,8 +56,8 @@ class _dot(torch.autograd.Function):
       TYPE c[TM, TN] = acc;
 
       // epilogue
-      int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
-      int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
+      int rxm[TM] = ridx * TM + 0 ... TM;
+      int rxn[TN] = ridy * TN + 0 ... TN;
       int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
       TYPE* pc[TM, TN] = C + offc;
       bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -95,7 +95,7 @@ class _dot(torch.autograd.Function):
     if dtype not in _dot.kernel:
       defines = {
           'TYPE' : dtype,
-          'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
+          'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
           'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
           'TM'   : [64, 128],
           'TN'   : [64, 128],
@@ -107,14 +107,12 @@ class _dot(torch.autograd.Function):
     # allocate output
     M, K = a.shape
     K, N = b.shape
-    c = triton.empty([M,N], dtype=dtype)
+    c = torch.empty([M,N], dtype=dtype, device=a.device)
     # enqueue
     grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
     time = kernel(a, b, c, 1., M, N, K,
-                  a.stride(0), b.stride(0), c.stride(0),
-                  grid=grid, bench=100)
-    print(2*M*N*K/(time*1e-6)*1e-9)
+                  a.stride(0), b.stride(0), c.stride(0), grid=grid)
     return c
 
 
@@ -126,8 +124,10 @@ M, N, K = 2048, 2048, 2048
 a = torch.rand((M, K)).cuda()
 b = torch.rand((K, N)).cuda()
 
+#a[:] = 1
+#b[:] = 1
 
-#zc = torch.matmul(a,b)
+zc = torch.matmul(a,b)
 zc_ = dot(a,b)
-#print(torch.allclose(zc, zc_))
+print(torch.allclose(zc, zc_))
diff --git a/python/setup.py b/python/setup.py
index d4c06b305..c3862aa0d 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -110,7 +110,7 @@ setup(
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
-    packages=['triton', 'triton/_C', 'triton/ops', 'triton/nn'],
+    packages=['triton', 'triton/_C', 'triton/ops'],
     install_requires=['numpy', 'torch', 'sympy'],
     package_data={'': data},
     ext_modules=[CMakeExtension('triton', 'triton/_C/')],
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index afaa70ecd..b1e9e8bc6 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -1,5 +1,5 @@
 from .kernel import *
-#import triton.ops
+import triton.ops
 #import triton.nn
 
 
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 9fee937a6..1cf354a3b 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -82,6 +82,7 @@ class kernel:
       grid = kwargs['grid']
     libtriton.register_grid((self.op_id, device), grid)
     # launch
+    #print(self.tys)
     params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
     torch.cuda.synchronize()
     torch.ops.triton.launch_kernel(self.op_id, device, params)
\ No newline at end of file
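The STRIDE_AM/STRIDE_AK swap above matches how a contiguous row-major A[M, K] is actually laid out: moving along M jumps lda (= K) elements while moving along K jumps one, so STRIDE_AM must be 'lda' and STRIDE_AK '1'. A quick stand-alone check:

import torch
a = torch.rand(2048, 1024)        # A[M, K], row-major
assert a.stride() == (1024, 1)    # (stride_am, stride_ak) = (lda, 1)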
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 464d588ca..f1fb719cf 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -5,8 +5,8 @@ from operator import mul
 from collections import OrderedDict
 from collections import namedtuple
 import re
+import string
 import triton
-# torch
 import torch
 # numpy -- ideally removed in a future release
 import numpy as np
@@ -22,33 +22,14 @@ class _einsum(torch.autograd.Function):
 
     #############################
     ## Triton-C code generation
     #############################
-    def print_cc(expr, axes_0, axes_1, axes_2):
-
-        class TritonCodePrinter(C89CodePrinter):
-
-            def __init__(self, axes_0, axes_1, axes_2):
-                super(TritonCodePrinter, self).__init__()
-                self.axes_0 = axes_0
-                self.axes_1 = axes_1
-                self.axes_2 = axes_2
-
-            def _print_Symbol(self, expr):
-                name = super(C89CodePrinter, self)._print_Symbol(expr)
-                if expr in self.axes_0:
-                    return f'r{name}[:, newaxis, newaxis]'
-                if expr in self.axes_1:
-                    return f'r{name}[newaxis, :, newaxis]'
-                if expr in self.axes_2:
-                    return f'r{name}[newaxis, newaxis, :]'
-                return name
-
-            def _print_Indexed(self, expr):
-                assert len(expr.indices) == 1
-                return "*(%s + %s)" % (self._print(expr.base.label),
-                                       self._print(expr.indices[0]))
-
-        return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)
-
+    def print_cc(expr, axes_0, axes_1, axes_2, prefix):
+        if expr in axes_0:
+            return f'{prefix}r{expr}[:, newaxis, newaxis]'
+        if expr in axes_1:
+            return f'{prefix}r{expr}[newaxis, :, newaxis]'
+        if expr in axes_2:
+            return f'{prefix}r{expr}[newaxis, newaxis, :]'
+        return expr
 
     def unpack_cc(tile, axes, prefix, remat):
         ret = ''
@@ -59,7 +40,7 @@ class _einsum(torch.autograd.Function):
             currs = ''.join(axes[: len(axes) - i])
             nexts = ''.join(axes[: len(axes) - (i + 1)])
             ty = '' if remat else 'int '
-            sz = '' if remat else f'[{tile}]'
+            sz = '' if remat or tile is None else f'[{tile}]'
             ret += f'    {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
             ret += f'    {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
         return ret
@@ -69,18 +50,26 @@ class _einsum(torch.autograd.Function):
         ret = dict(zip(expr, ret))
         return ret
 
-    def make_kernel(name, dtype, mask,
+    def make_kernel(name, dtype,
                     expr_a, expr_b, expr_c,
+                    sparse_a, sparse_b, sparse_c,
                     axes_m, axes_n, axes_k, axes_b,
                     multipleof_a, multipleof_b, multipleof_c,
                     stride_a_last, stride_b_last, stride_c_last,
                     lut_mode_a, lut_mode_b,
                     delta_a, delta_b,
-                    subscripted, varnames):
+                    blocks):
 
         use_lut_a = True
         use_lut_b = True
 
+        outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
+        outer_dense_a  = [x for x in expr_a if x not in sparse_a and x not in axes_k]
+        outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
+        outer_dense_b  = [x for x in expr_b if x not in sparse_b and x not in axes_k]
+        outer_dense_c  = [x for x in expr_c if x not in sparse_c and x not in axes_k]
+
         src = ""
 
         if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
@@ -105,41 +94,91 @@ __global__ void {name}(
         for d in dim:
             src += f", int dim_{d}"
         src += "\n            "
-        for dim, name, mult in zip([expr_a, expr_b, expr_c],
-                                   ['a', 'b', 'c'],
-                                   [multipleof_a, multipleof_b, multipleof_c]):
+        for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
+                                           ['a', 'b', 'c'],
+                                           [multipleof_a, multipleof_b, multipleof_c],
+                                           [sparse_a, sparse_b, sparse_c]):
             for d in range(len(dim) - 1):
-                attr = f'__multipleof({mult})'
-                src += f", int stride_{name}_{d} {attr}"
+                if sparse and dim[d] == sparse[0]:
+                    src += f', int stride_{name}_block __multipleof({mult})'
+                src += f", int stride_{name}_{d} __multipleof({mult})"
             src += "\n            "
         if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
            src += f", int rem_delta_a __multipleof({multipleof_a})"
-        elif lut_mode_a == _einsum.LUT_MODE.DRAM:
+        elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
             src += ", int* AD __noalias __readonly __aligned(16)"
         src += "\n            "
         if lut_mode_b == _einsum.LUT_MODE.SCALAR:
             src += f", int stride_b_inner __multipleof({multipleof_b})"
             src += f", int rem_delta_b __multipleof({multipleof_b})"
-        elif lut_mode_b == _einsum.LUT_MODE.DRAM:
+        elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
             src += ", int* BD"
-        src += "\n"
-        for ptr in subscripted:
-            src += f", int* {ptr}"
-        for name in varnames:
-            src += f", int {name}"
+        src += "\n            "
+        if sparse_c:
+            src += ", int* CD"
+        if sparse_a or sparse_b:
+            src += ", int width"
         src += """) {
+
+    // program identifiers
+    int pid_0 = get_program_id(0);
+    int pid_1 = get_program_id(1);
+
+"""
+        if sparse_a:
+            src += f"""
+    int off_n = pid_0 / width;
+    int off_header = pid_0 % width;
+    int* header = AD + off_header * {2 + len(outer_sparse_a)};
+    int* pdelta = AD + *(header + 0);
+    matmul_k = *(header + 1);"""
+            for i, d in enumerate(outer_sparse_a):
+                src += f"""
+    int off_{d} = *(header + {2 + i});"""
+            src += f"""
+    int inca = *(pdelta + 0);
+    int incb = *(pdelta + 1);
+    int off_{''.join(map(str, outer_dense_a))} = pid_1;
+""" + _einsum.unpack_cc(None, outer_dense_a, "off_", False)
+        elif sparse_b:
+            src += f"""
+    int off_m = pid_0 / width;
+    int off_header = pid_0 % width;
+    int* header = BD + off_header * {2 + len(outer_sparse_b)};
+    int* pdelta = BD + *(header + 0);
+    matmul_k = *(header + 1);"""
+            for i, d in enumerate(outer_sparse_b):
+                src += f"""
+    int off_{d} = *(header + {2 + i});"""
+            src += f"""
+    int incb = *(pdelta + 0);
+    int inca = *(pdelta + 1);
+    int off_{''.join(map(str, outer_dense_b))} = pid_1;
+""" + _einsum.unpack_cc(None, outer_dense_b, "off_", False)
+        elif sparse_c:
+            src += f"""
+    // load LUT header
+    int *header = CD + pid_0 * {len(sparse_c)};"""
+            for i, d in enumerate(sparse_c):
+                src += f"""
+    int off_{d} = *(header + {i});"""
+            src += f"""
+    int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
+        else:
+            src += """
     // re-order outer program ids
     int grid_m = (matmul_m + TM - 1) / TM;
     int grid_n = (matmul_n + TN - 1) / TN;
-    int pid_mn = get_program_id(0) / div_m;
-    int pid_n = pid_mn % grid_n;
-    int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
-
-    // get batch program id
-    int pid_b = get_program_id(1);
+    int off_mn = pid_0 / div_m;
+    int off_n = off_mn % grid_n;
+    int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
+    int off_b = get_program_id(1);"""
+        src += """
 #if TZ == 1
     int off_k = 0;
 #else
@@ -155,55 +194,71 @@ __global__ void {name}(
 
     // create ranges
 """
-        rk = 'r{}'.format(''.join(map(str,axes_k)))
-        for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
-                                   ['TM', 'TN', 'TB', 'TK'],
-                                   ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
-            currs = ''.join(map(str,axes))
-            if axes:
-                src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
-                src += _einsum.unpack_cc(tile, axes, 'r', False)
-        src += """
+        sparse = sparse_a + sparse_b + sparse_c
+        for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
+                                             ['TM', 'TN', 'TB', 'TK'],
+                                             ['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
+                                             [['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
+            if not axes:
+                continue
+            currs = ''.join(map(str,axes))
+            has_sparse_component = set(axes) & set(sparse)
+            if has_sparse_component:
+                src += f"    int r{currs}[{tile}] = 0 ... {tile};\n"
+                src += _einsum.unpack_cc(tile, axes, f'r', False)
+            else:
+                src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
+                src += _einsum.unpack_cc(tile, axes, f'r', False)
+            for pfx in prefixes:
+                for d in axes:
+                    is_dense_dim = d not in sparse
+                    is_dense_storage = (pfx == 'a' and not sparse_a) or\
+                                       (pfx == 'b' and not sparse_b) or\
+                                       (pfx == 'c' and not sparse_c)
+                    if not is_dense_dim and is_dense_storage:
+                        src += f"    int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
+                    elif is_dense_dim and has_sparse_component:
+                        src += f"    int {pfx}r{d}[{tile}] = off_{d};\n"
+                    else:
+                        src += f"    int {pfx}r{d}[{tile}] = r{d};\n"
+
+        src += f"""
     // initialize pointers to A
-    int offa[TM, TK, TB] = """
+    int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
         for i, sym in enumerate(expr_a):
-            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
+            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
             stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
-            if i > 0:
-                src += ' + '
-            src += f"({ccode}) * {stride}\n            "
+            src += f" + ({ccode}) * {stride}\n            "
         src += ';'
         src += """
     TYPE *pa[TM, TK, TB] = A + offa;"""
-        if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
+
+        if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
             spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
             cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
             src += f"""
-    // initialize pointers to A look-up table
     int offadelta[TK] = off_k + 0 ... TK;
     int {spec} *padelta[TK] = {cast}AD + offadelta;
     int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
 
+        src += f"""
     // initialize pointers to B
-    int offb[TK, TN, TB] = """
+    int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
         for i, sym in enumerate(expr_b):
-            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
+            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
             stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
-            if i > 0:
-                src += ' + '
-            src += f"({ccode}) * {stride}\n            "
+            src += f" + ({ccode}) * {stride}\n            "
         src += ';'
         src += """
     TYPE *pb[TK, TN, TB] = B + offb;"""
-        if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
+        if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
             spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
             cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
             src += f"""
@@ -211,105 +266,122 @@ __global__ void {name}(
     int offbdelta[TK] = off_k + 0 ... TK;
     int *pbdelta[TK] = BD + offbdelta;"""
 
+        rk = 'r{}'.format(''.join(map(str,axes_k)))
         src += f"""
 
     // prefetch
     int prefetch_k = select(rem_k > 0, rem_k, TK);
-    bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
-    bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
-    bool checkk[TK] = {rk} < prefetch_k;
-    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
-    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
+    bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
+    bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
+    bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
+    bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
+    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
     TYPE a[TM, TK, TB] = checka ? *pa : 0;
     TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
-        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
+        if sparse_a:
+            src += f"""
+    // update pointers to look-up tables
+    pdelta += 2;
+    int incda = *(pdelta + 0);
+    int incdb = *(pdelta + 1);
+    pa += incda;
+    pb += incdb;"""
+        if sparse_b:
+            src += f"""
+    // update pointers to look-up tables
+    pdelta += 2;
+    int incdb = *(pdelta + 0);
+    int incda = *(pdelta + 1);
+    pa += incda;
+    pb += incdb;"""
+
+        if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
             src += """
     pa += rem_delta_a;"""
-        else:
+        elif not sparse_a and not sparse_b:
             src += """
     pa += incda;
     padelta += TK;
     incda = (*padelta)[newaxis, :, newaxis];"""
-        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
+        if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
             src += """
     pb += rem_delta_b;"""
-        else:
+        elif not sparse_a and not sparse_b:
             src += """
     pb += (*pbdelta)[:, newaxis, newaxis];
     pbdelta += TK;"""
 
         src += f"""
+
     // accumulate
     float acc[TM, TN, TB] = 0;
     for(int k = matmul_k; k > 0; k -= TK) {{
         acc += a @ b;
-#ifdef MASK
-        uint32 bits[TM, TN, TB] = bitcast(acc);
-        acc = bitcast(bits & MASK);
-#endif
-
+
+        // load inputs
         checkk = k > TK;
-        checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
-        checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
+        checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
+        checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
         a = *?(checka)pa;
-        b = *?(checkb)pb;"""
-
-        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
+        b = *?(checkb)pb;
+
+        // update pointers"""
+        if sparse_a:
             src += """
-        pa += stride_a_inner;"""
+        pdelta += 2;
+        incda = *(pdelta + 0);
+        incdb = *(pdelta + 1);
+        pa += incda;
+        pb += incdb;
+        """
+        elif sparse_b:
+            src += """
+        pdelta += 2;
+        incdb = *(pdelta + 0);
+        incda = *(pdelta + 1);
+        pa += incda;
+        pb += incdb;
+        """
         else:
-            src += """
+            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
+                src += """
+        pa += stride_a_inner;"""
+            else:
+                src += """
         pa += incda;
         padelta += TK;
         incda = (*padelta)[newaxis, :, newaxis];"""
-
-
-        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
-            src += """
+            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
+                src += """
         pb += stride_b_inner;"""
-        else:
-            src += """
+            else:
+                src += """
         pb += (*pbdelta)[:, newaxis, newaxis];
         pbdelta += TK;"""
 
         src += f"""
     }}
     TYPE c[TM, TN, TB] = acc;
-
-    // re-materialize ranges
-    pid_mn = get_program_id(0) / div_m;
-    pid_n = pid_mn % grid_n;
-    pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
-"""
-        for axes, tile, off in zip([axes_m, axes_n, axes_b],
-                                   ['TM', 'TN', 'TB'],
-                                   ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
-            currs = ''.join(map(str,axes))
-            if axes:
-                src += f"    r{currs} = {off} + 0 ... {tile};\n"
-                src += _einsum.unpack_cc(tile, axes, 'r', True)
-
-        src += """
+
     // initialize pointers to C
-    int offc[TM, TN, TB] = """
+    int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
         for i, sym in enumerate(expr_c):
             stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
-            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
-            if i > 0:
-                src += ' + '
-            src += f"({ccode}) * {stride}\n            "
+            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
+            src += f"\n              + ({ccode}) * {stride}"
         src += ';'
         src += """
     TYPE *pc[TM, TN, TB] = C + offc;
 
     // bounds-checking
-    checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
-    checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
-    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
-                              checkn[newaxis, :, newaxis];
+    bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
+    bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
+    bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
+                              checkcn[newaxis, :, newaxis];
 
     // write back
 #if TZ == 1
@@ -330,11 +402,11 @@ __global__ void {name}(
 }
 """
         # compilation options
-        TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
+        TM, TN, TB, TZ = [32], [32], 1, [1]
         TK = 16 if dtype==torch.float16 else 8
         defines =  {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
-        if mask is not None:
-            defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
+        for d, B in blocks.items():
+            defines[f'BLOCK{d}'] = B
         # create kernel
         ret = triton.kernel(src, defines=defines)
         # set constant
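The DSD look-up table built by make_dsd_delta in the next hunk is what the sparse_a/sparse_b header-loading code generated above consumes: one (offset, K-depth, outer indices...) record per nonzero outer row, followed by interleaved (delta_a, delta_b) pointer-increment pairs; header offsets are rebased so that AD + *(header + 0) points straight at a row's first pair. A toy paraphrase of the header half, assuming a single sparse reduction index (illustrative only):

import torch
layout = torch.tensor([[1, 0, 1],
                       [0, 1, 1]])            # nonzero blocks of A per output row
depth = layout.sum(1)                         # inner-loop trip count per row
offsets = torch.zeros_like(depth)
offsets[1:] = torch.cumsum(depth[:-1], 0)     # where each row's deltas begin
header = torch.stack((offsets, depth, torch.arange(2)), dim=1)
# kernel side: pdelta = AD + *(header + 0); matmul_k = *(header + 1);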
@@ -376,27 +448,81 @@ __global__ void {name}(
                 k = k // dims[d]
         return ret
 
-    def make_delta(axes, step, stride, dims, symbols, arrays):
+
+    def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
+        # depth of reductions
+        depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
+        # outer dimension indices
+        outer = torch.nonzero(depth, as_tuple=False)
+        outer = [outer[:,i] for i in range(outer.shape[1])]
+        # find offset of outer dimensions
+        depth = depth.view(-1)
+        offsets = torch.zeros_like(depth)
+        offsets[1:] = torch.cumsum(depth[:-1], 0)
+        # compute delta for b
+        # TODO: support multiple sparse red indices
+        col = next((i for i, d in enumerate(sparse) if d in axes), None)
+        block = blocks[sparse[-1].upper()]
+        div = block // step
+        delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
+        delta_b *= block
+        delta_b = [delta_b + step*i for i in range(div)]
+        delta_b = torch.stack(delta_b, dim=1)
+        delta_b = delta_b.view(-1)
+        # compute delta for a
+        bstride = 1
+        for d in sparse[::-1]:
+            if d in axes:
+                break
+            bstride *= blocks[d.upper()]
+        order = [d for d in sparse if d not in axes] +\
+                [d for d in sparse if d in axes]
+        idx = [sparse.index(d) for d in order]
+        layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
+        layout = layout.permute(*idx)
+        delta_a = layout[layout > 0] - 1
+        delta_a *= np.prod(list(blocks.values()))
+        saved = delta_a[offsets]
+        delta_a[1:] = delta_a[1:] - delta_a[:-1]
+        delta_a = delta_a.view(-1, 1).repeat(1, div)
+        delta_a[:, 1:] = step*bstride
+        delta_a[:, 0] -= (div - 1)*step*bstride
+        delta_a[offsets, 0] = saved
+        delta_a = delta_a.view(-1)
+        delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
+        # form look-up table
+        depth *= blocks[symbols[-1].upper()]
+        offsets *= div
+        header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
+        nouter = 2 + len(outer)
+        header[::nouter] = header[::nouter]*2 + header.shape[0]
+        lut = torch.cat((header, delta)).int().cpu().numpy()
+        return lut, nouter, _einsum.LUT_MODE.DRAM
+
+    def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
         # symbolic pointer increments
+        symbols = [sp.symbols(x) for x in symbols]
         delta = _einsum.symbolic_delta(symbols, axes)
         args =  [f'stride{d}' for d in range(len(stride))]
         args += [f'{sk}' for sk in axes]
         args += [f'next{sk}' for sk in axes]
-        args += [f'{sk}' for sk, _ in arrays]
         fn = sp.lambdify(args, delta, 'numpy')
-        # inner axes values
-        inner = [dims[d] for d in axes]
-        inner = np.prod(inner)
-        rem = inner % step
-        rem = rem if rem > 0 else step
-        # k = [0, 1, ..., step,
-        #      rem, rem + 1, ... rem + inner]
-        k = np.concatenate((np.arange(step),
-                            np.arange(rem, inner))).astype(np.int32)
-        # nextk = [rem, 1 + rem, ..., step + rem,
-        #          rem + step, rem + 1 + step, ..., inner + step]
-        nextk = np.concatenate((k[:step] + rem,
-                                k[step:] + step))
+        if lut is None:
+            # inner axes values
+            inner = [dims[d] for d in axes]
+            inner = np.prod(inner)
+            rem = inner % step
+            rem = rem if rem > 0 else step
+            # k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
+            # nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
+            k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
+            nextk = np.concatenate((k[:step] + rem, k[step:] + step))
+        else:
+            idx = (lut[:lut[0]:nouter] - lut[0])//2
+            k = lut[lut[0]+1::2]
+            k = np.insert(k, idx, 0)
+            nextk = k[1:]
+            k = k[:-1]
         # offsets
         off      = _einsum.unpack_offset(k, axes, dims)
         nextoff  = _einsum.unpack_offset(nextk, axes, dims)
@@ -404,71 +530,20 @@ __global__ void {name}(
         args  = [s for s in stride]
         args += [off[sk] for sk in axes]
        args += [nextoff[sk] for sk in axes]
-        args += [x for _, x in arrays]
         delta = fn(*args)
+        delta = np.maximum(delta, 0)
+        if lut is not None:
+            idx = idx[1:] + np.arange(idx.shape[0] - 1)
+            delta = np.delete(delta, idx)
+            lut[lut[0]+1::2] = delta
+            return None, None
        return delta, _einsum.lut_mode(delta[step:-step])
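make_delta still selects between the two dense LUT modes that the generated kernel signature branches on: if every interior pointer increment is identical, the kernel takes a scalar stride (LUT_MODE.SCALAR, the stride_a_inner/rem_delta_a arguments); otherwise the full delta table is shipped to DRAM as an int*. A minimal paraphrase of that decision, with made-up increments:

import numpy as np
delta = np.array([8, 8, 8, 8])                 # interior increments, i.e. delta[step:-step]
is_scalar = bool(np.all(delta == delta[0]))    # True -> SCALAR, False -> DRAM table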
 
-    ############################
-    ## Einsum parsing
-    ############################
-
-    def uniq(seq):
-        seen = set()
-        seen_add = seen.add
-        return [x for x in seq if not (x in seen or seen_add(x))]
-
-    def parse_axes(expr_a, expr_b, expr_c, subscripted):
-        is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted
-        sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
-        sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
-        sym_c = [x for s in expr_c for x in s.free_symbols]
-        batch = [d for d in sym_a if d in sym_b and d in sym_c]
-        outer = [d for d in sym_a if d not in sym_b and d in sym_c]
-        inner = [d for d in sym_a if d in sym_b and d not in sym_c]
-        variables = [d for d in sym_a if d not in sym_b and d not in sym_c]
-        return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner), variables
-
-    def replace_subscript(expr, arrays):
-        # replace array indexing by Indexed()
-        indexed = re.findall('([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
-        for x in indexed:
-            arrays.append(x[0])
-            expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
-        return expr
-
-    def parse_expr(expr, arrays):
-        # extract symbols
-        sym = []
-        i = 0
-        while i < len(expr):
-            d = expr[i]
-            if d == '(':
-                size = expr[i:].find(')')
-                d = expr[i : i + size + 1]
-                d = _einsum.replace_subscript(d, arrays)
-                sym.append(parse_expr(d))
-                i += size + 1
-            else:
-                sym.append(parse_expr(d))
-                i += 1
-        return sym
-
-    ############################
-    ## Preprocessing
-    ############################
-
-    @staticmethod
-    def pad(tensor, pad):
-        pad = pad + [0] * (2*len(tensor.shape) - len(pad))
-        begin = [ x if x > 0 else None for x in pad[-1::-2]]
-        end   = [-x if x > 0 else None for x in pad[-2::-2]]
-        slices = [slice(b, e) for b, e in zip(begin, end)]
-        tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0)
-        tensor = tensor[slices]
-        return tensor
-
+    def make_sdd_lut(layout_c, sparse_c, blocks):
+        nnz = torch.nonzero(layout_c, as_tuple=False)
+        lut = nnz.reshape(-1).int().cuda()
+        return lut
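make_sdd_lut above is deliberately minimal: the C-side table is just the multi-index of every nonzero output block, flattened, which the sparse_c header code in the generated kernel reads back one record per program id. For example (on CPU here; the real table is moved to the GPU with .cuda()):

import torch
layout_c = torch.tensor([[1, 0],
                         [0, 1]])
lut = torch.nonzero(layout_c, as_tuple=False).reshape(-1).int()
# tensor([0, 0, 1, 1], dtype=torch.int32): blocks (0, 0) and (1, 1)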
 
    ############################
    ## Compilation
    ############################
 
@@ -478,69 +553,79 @@ __global__ void {name}(
 
        locks = None
        kernel_cache = dict()
-
-        @staticmethod
-        def _tile(M, N, B, TMs, TNs, TBs, TZs, TK):
-            smp = 15
-            # occupancy estimation
-            grid = lambda TM, TN, TB, TZ:   \
-                        triton.cdiv(M, TM)* \
-                        triton.cdiv(N, TN)* \
-                        triton.cdiv(B, TB)* \
-                        TZ
-            occupancy = lambda TM, TN, TB, TZ: \
-                           min(grid(TM, TN, TB, TZ), 4*smp)
-            # arithmetic intensity estimation
-            intensity = lambda TM, TN: \
-                           TM * TN * TK / (TM*TK + TK*TN)
-            # occupancy/intensity for all configurations
-            estimates = {(TM, TN, TB, TZ): (occupancy(TM, TN, TB, TZ), intensity(TM, TN)) \
-                        for TM in TMs \
-                        for TN in TNs \
-                        for TB in TBs \
-                        for TZ in TZs }
-            # returns configuration that maximizes occupancy subject to maximizing intensity
-            estimates = sorted(estimates.items(),
-                               key=lambda item: item[1],
-                               reverse=True)
-            return estimates[0][0]
-
-        def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c, varnames):
+
+        def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
            # parse symbols
            expr_a, expr_bc = einsum.split(",")
            expr_b, expr_c  = expr_bc.split("->")
-            subscripted = []
-            sym_a = _einsum.parse_expr(expr_a, subscripted)
-            sym_b = _einsum.parse_expr(expr_b, subscripted)
-            sym_c = _einsum.parse_expr(expr_c, subscripted)
+            sym_a = expr_a.lower()
+            sym_b = expr_b.lower()
+            sym_c = expr_c.lower()
+            sparse_a = [x.lower() for x in expr_a if x.isupper()]
+            sparse_b = [x.lower() for x in expr_b if x.isupper()]
+            sparse_c = [x.lower() for x in expr_c if x.isupper()]
+            layout_a = layouts.get(expr_a)
+            layout_b = layouts.get(expr_b)
+            layout_c = layouts.get(expr_c)
            # parse axes
-            axes_b, axes_m, axes_k, var = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
-            _, axes_n, _, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
+            axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
+            axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
+            axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
+            axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
            axes = axes_b + axes_m + axes_n + axes_k
-            # unresolved symbols
-            unresolved = [x for x in map(str, var) if x not in varnames]
-            if unresolved:
-                raise ValueError(f'unresolved symbols: {unresolved}')
+            # check block sizes
+            for d in sparse_a + sparse_b + sparse_c:
+                if d.upper() not in blocks:
+                    raise ValueError(f'unspecified block size for dimension: {d.upper()}')
+            # check layout is present
+            if sparse_a and layout_a is None:
+                raise ValueError('A is sparse but no layout provided')
+            if sparse_b and layout_b is None:
+                raise ValueError('B is sparse but no layout provided')
+            if sparse_c and layout_c is None:
+                raise ValueError('C is sparse but no layout provided')
            # check dimensions
-            dims_a  = dict(zip(sym_a, shape_a))
-            dims_b  = dict(zip(sym_b, shape_b))
-            dims_c  = dict(zip(sym_c, shape_c))
-            for axes in [axes_b, axes_k]:
-                for d in axes:
-                    dim_a = dims_a[d] if d in sym_a else None
-                    dim_b = dims_b[d] if d in sym_b else None
-                    if dim_a and dim_b and dim_a != dim_b:
-                        raise ValueError(f'incompatible dimension {d}'
-                                         f' (a: {dim_a}; b: {dim_b})')
+            dims_a  = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
+            dims_b  = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
+            dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
+            dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
+            # TODO: could be cleaner
+            read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
+            for d in axes_b + axes_m + axes_n + axes_k:
+                dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
+                dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
+                if d in axes_b and dim_a and dim_b and dim_a != dim_b:
+                    raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
+                if d in axes_k and dim_a and dim_b and dim_a != dim_b:
+                    raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
            dims = dict()
            dims.update(dims_a)
            dims.update(dims_b)
-            dims.update(dims_c)
+            for i, d in enumerate(sparse_a):
+                dims[d] = layout_a.shape[i] * blocks[d.upper()]
+            for i, d in enumerate(sparse_b):
+                dims[d] = layout_b.shape[i] * blocks[d.upper()]
+            # allocate output
+            shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
+            if sparse_c:
+                shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
+            stride_c = [None] * len(shape_c)
+            stride_c[-1] = 1
+            for i in reversed(range(len(shape_c) - 1)):
+                stride_c[i] = stride_c[i+1] * shape_c[i+1]
            # look-up tables
            TK = 16 if dtype == torch.float16 else 8
-            arrays = [(x, arrays[x]) for x in subscripted]
-            delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
-            delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
+            if sparse_a and not sparse_b:
+                delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
+                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
+            if sparse_b and not sparse_a:
+                delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
+                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
+            if not sparse_a and not sparse_b:
+                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
+                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
+            if sparse_c:
+                delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
            # hash for recompilation
            stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
            stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
@@ -548,7 +633,7 @@ __global__ void {name}(
            stride_a_last = stride_a[-1]
            stride_b_last = stride_b[-1]
            stride_c_last = stride_c[-1]
-            name = f'{dtype}_{mask}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
+            name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
                   f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
                   f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
            # recompile if necessary
@@ -556,14 +641,15 @@ __global__ void {name}(
            if name not in cache:
                cachesize = len(cache)
                cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
-                                                  dtype, mask,
+                                                  dtype,
                                                  sym_a, sym_b, sym_c,
+                                                  sparse_a, sparse_b, sparse_c,
                                                  axes_m, axes_n, axes_k, axes_b,
                                                  stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                  stride_a_last, stride_b_last, stride_c_last,
                                                  lut_mode_a, lut_mode_b,
                                                  delta_a, delta_b,
-                                                  subscripted, varnames)
+                                                  blocks)
            self.kernel = cache[name]
            # Initialize locks
            if _einsum.instance.locks is None:
@@ -576,42 +662,55 @@ __global__ void {name}(
            M = reduce(mul, dim_m, 1)
            N = reduce(mul, dim_n, 1)
            K = reduce(mul, dim_k, 1)
-            B = reduce(mul, dim_b, 1)
+            B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
            stride_a = list(stride_a[:-1])
            stride_b = list(stride_b[:-1])
            stride_c = list(stride_c[:-1])
-            arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
            alpha = 1.
            div_m = 1
-            self.args = [None, None, None,
-                         _einsum.instance.locks,
-                         alpha, M, N, K, div_m] +\
-                         dim_m + dim_n + dim_k + dim_b +\
-                         stride_a + stride_b + stride_c
+            self.args  = [None, None, None]
+            self.args += [_einsum.instance.locks]
+            self.args += [alpha, M, N, K, div_m]
+            self.args += dim_m
+            self.args += dim_n
+            self.args += dim_k
+            self.args += dim_b
+            self.args += stride_a
+            self.args += stride_b
+            self.args += stride_c
            # LUT for A
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_a[TK], delta_a[0]]
-            if lut_mode_a == _einsum.LUT_MODE.DRAM:
+            elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_a).cuda()]
            # LUT for B
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_b[TK], delta_b[0]]
-            if lut_mode_b == _einsum.LUT_MODE.DRAM:
+            elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_b).cuda()]
-            # Einsum dependents
-            self.args += arrays
-            self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
-                                     triton.cdiv(N, opt.d('TN')),
-                                     triton.cdiv(B, opt.d('TB')),
-                                     opt.d('TZ')]
+            # LUT for C
+            if sparse_c:
+                self.args += [delta_c]
+            if sparse_a or sparse_b:
+                width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
+                self.args += [width]
+            # Grid
+            if sparse_a:
+                self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
+            elif sparse_b:
+                self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
+            elif sparse_c:
+                width = int(layout_c.sum())
+                self.grid = lambda opt: [width, B, opt.d('TZ')]
+            else:
+                self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
+                                         triton.cdiv(N, opt.d('TN')),
+                                         triton.cdiv(B, opt.d('TB')),
+                                         opt.d('TZ')]
            # position of dynamic arguments
            self.pos_a = 0
            self.pos_b = 1
            self.pos_c = 2
-            # user-provided variables
-            self.pos_vars = len(self.args)
-            self.varnames = varnames
-            self.args += [None] * len(varnames)
            # save information on the operation
            self.expr_a = expr_a
            self.expr_b = expr_b
@@ -620,16 +719,17 @@ __global__ void {name}(
            self.matmul_M = M
            self.matmul_N = N
            self.matmul_K = K
-            self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
+            # output shape
+            self.shape_c = shape_c
 
-        def run(self, a, b, c, values, bench):
+        def run(self, a, b):
+            c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
            self.args[self.pos_a] = a
            self.args[self.pos_b] = b
            self.args[self.pos_c] = c
-            for i, name in enumerate(self.varnames):
-                self.args[self.pos_vars + i] = values[name]
-            return self.kernel(*self.args, grid=self.grid, bench=bench)
+            self.kernel(*self.args, grid=self.grid)
+            return c
 
 
@@ -639,25 +739,23 @@ __global__ void {name}(
 
    ############################
 
    instance_cache = dict()
-    registry = triton.utils.id_dict()
+    registry = dict()
 
    @staticmethod
-    def forward(ctx, expr, a, b, output, mask, arrays, bench, values):
+    def forward(ctx, expr, a, b, layouts, blocks):
        # compile einsum instance
        cache = _einsum.instance_cache
-        key = (expr, a.dtype,
-               a.stride(), b.stride(), output.stride(),
-               a.shape, b.shape, output.shape, mask)
+        key = (expr, a.dtype,
+               a.stride(), b.stride(),
+               a.shape , b.shape)
        if key not in cache:
            cache[key] = _einsum.instance(expr, a.dtype,
-                                          a.stride(), b.stride(), output.stride(),
-                                          a.shape, b.shape, arrays,
-                                          mask, output.shape, values.keys())
+                                          a.stride(), b.stride(),
+                                          a.shape , b.shape ,
+                                          layouts, blocks)
        instance = cache[key]
-        # run and mark as dirty output modified in-place
-        perf = instance.run(a, b, output, values, bench)
-        ctx.mark_dirty(output)
+        # run the kernel
+        c = instance.run(a, b)
        # save information in context
-        ctx.is_extended = instance.is_extended
        ctx.expr_a = instance.expr_a
        ctx.expr_b = instance.expr_b
        ctx.expr_c = instance.expr_c
@@ -665,10 +763,8 @@ __global__ void {name}(
        ctx.matmul_M = instance.matmul_M
        ctx.matmul_N = instance.matmul_N
        ctx.matmul_K = instance.matmul_K
-        ctx.forward_ms = perf
        ctx.save_for_backward(a, b)
-        _einsum.registry[output] = ctx
-        return output
+        return c
 
 
    ############################
@@ -677,9 +773,6 @@ __global__ void {name}(
 
    @staticmethod
    def backward(ctx, dy):
-        if ctx.is_extended:
-            raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
-                                      ' print write your own autograd function')
        a, b = ctx.saved_tensors
        expr_a = ctx.expr_a
        expr_b = ctx.expr_b
@@ -694,10 +787,8 @@ __global__ void {name}(
        if ctx.needs_input_grad[2]:
            db = torch.empty_like(b)
            einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
-        return None, da, db, None, None, None, None, None
+        return None, da, db, None, None, None, None, None, None, None
 
 
-def einsum(expr, a, b, output,
-           mask=None, arrays=dict(),
-           bench=False, values=dict()):
-    return _einsum.apply(expr, a, b, output, mask, arrays, bench, values)
+def einsum(expr, a, b, layouts = None, blocks = dict()):
+    return _einsum.apply(expr, a, b, layouts, blocks)
\ No newline at end of file
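For reference, the reshaped entry point drops the preallocated output, mask, arrays and varnames arguments in favor of layouts/blocks and returns a freshly allocated result. A minimal block-sparse call mirroring the third case in python/examples/test.py above (assumes a CUDA device):

import torch
import triton

a = torch.rand(2, 2, 256, 256, device='cuda')
b = torch.rand(2, 2, 256, 256, device='cuda')
# C is block-sparse in (H, M, N): one {0, 1} flag per 32x32 block per head
layout_c = torch.randint(0, 2, (2, 8, 8), device='cuda')
c = triton.ops.einsum('bhmk,bhkn->bHMN', a, b,
                      {'bHMN': layout_c}, {'H': 1, 'M': 32, 'N': 32})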