[GENERAL] Various improvements:
* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in the parser
@@ -136,7 +136,7 @@ public:
   value *create_get_num_program(unsigned axis, const std::string &name = "");
   value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
   value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
-  value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
+  value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
   value *create_exp(value* arg, const std::string &name = "");
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");

@@ -603,13 +603,13 @@ public:

 class atomic_add_inst: public builtin_inst {
 private:
-  atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
+  atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
   std::string repr_impl() const { return "atomic_add"; }
   _TRITON_DEFINE_CLONE(atomic_add_inst)
   _TRITON_DEFINE_ACCEPT(atomic_add_inst)

 public:
-  static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
 };

 class exp_inst: public builtin_inst {
@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
 void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
-  size_t ld = layouts_->get(ptr)->get_order(0);
+  auto order = layouts_->get(ptr)->get_order();
+  size_t ld;
+  for(size_t i = 0; i < order.size(); i++){
+    ld = order[i];
+    if(ld < x->get_type()->get_tile_rank())
+      break;
+  }
+  //size_t ld = layouts_->get(ptr)->get_order(0);
   unsigned alignment = alignment_->get(ptr, ld);
   distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
   distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 }

 void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+  if(add->get_type()->is_tile_ty()){
+    ir::value* ptr = add->get_operand(0);
+    ir::value* val = add->get_operand(1);
+    ir::value* msk = add->get_operand(2);
+    distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
+    distributed_tile* vals = (distributed_tile*)tmap_.at(val);
+    distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
+    for_each(ptr, [&](indices_t idx){
+      Value *rmw_ptr = ptrs->get_value(idx);
+      Value *rmw_val = vals->get_value(idx);
+      Value *rmw_msk = msks->get_value(idx);
+      BasicBlock *current_bb = builder_->GetInsertBlock();
+      Function *parent = builder_->GetInsertBlock()->getParent();
+      BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
+      BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
+      builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
+      builder_->SetInsertPoint(mask_then_bb);
+      builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
+                                AtomicOrdering::Monotonic,
+                                SyncScope::System);
+      builder_->CreateBr(mask_done_bb);
+      builder_->SetInsertPoint(mask_done_bb);
+    });
+  }
+  else{
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
   Value *rmw_ptr = vmap_.at(add->get_operand(0));
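Note: LLVM has no masked atomicrmw, so the generator predicates each lane with explicit mask_then/mask_done control flow around a scalar atomicrmw fadd. A minimal Python sketch of the per-element semantics being emitted (names are illustrative, not Triton API):

    # reference semantics of the generated code, one "lane" per tile element
    def masked_atomic_fadd(ptrs, vals, msks):
        for p, v, m in zip(ptrs, vals, msks):
            if m:           # CondBr into "mask_then"
                p[0] += v   # atomicrmw fadd (monotonic ordering, system scope) on GPU
            # fall through to "mask_done"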
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
   builder_->CreateBr(tid_0_done_bb);
   builder_->SetInsertPoint(tid_0_done_bb);
   tgt_->add_memfence(module, *builder_);
+  }
 }

 void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
 void generator::visit_basic_block(ir::basic_block * block) {
   BasicBlock *parent = (BasicBlock*)vmap_[block];
   builder_->SetInsertPoint(parent);
-  for(ir::instruction *i: block->get_inst_list())
+  for(ir::instruction *i: block->get_inst_list()){
+    // std::cout << typeid(*i).name() << std::endl;
     visit_value(i);
+  }
   vmap_[block] = builder_->GetInsertBlock();
 }
@@ -253,7 +253,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
   cu_context::context_switcher ctx(*context);
-  // std::cout << source << std::endl;
+  // std::cout << source << std::endl;
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
   unsigned int errbufsize = 8096;
@@ -307,8 +307,8 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &na
   return insert(atomic_exch_inst::create(ptr, val, name));
 }

-value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
-  return insert(atomic_add_inst::create(ptr, val, name));
+value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
+  return insert(atomic_add_inst::create(ptr, val, msk, name));
 }

 value *builder::create_exp(value *arg, const std::string &name){
@@ -736,14 +736,15 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string

 // atomic add

-atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
-  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
+atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
+  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
   set_operand(0, ptr);
   set_operand(1, val);
+  set_operand(2, msk);
 }

-instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
-  return new atomic_add_inst(ptr, val, name, next);
+instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
+  return new atomic_add_inst(ptr, val, msk, name, next);
 }

 // exp
@@ -523,7 +523,7 @@ void BinaryOp::RelationalOpTypeChecking() {
   }
     Convert();
   }
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }

@@ -538,7 +538,7 @@ void BinaryOp::EqualityOpTypeChecking() {
     Error(this, "invalid operands to binary %s", tok_->str_.c_str());
     Convert();
   }
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }

@@ -558,7 +558,7 @@ void BinaryOp::LogicalOpTypeChecking() {
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
     Error(this, "the operand should be arithmetic type or pointer");
-  type_ = ArithmType::New(T_INT);
+  type_ = ArithmType::New(T_BOOL);
   Broadcast(this, lhs_, rhs_, type_);
 }
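With relational, equality, and logical operators now typed as T_BOOL rather than T_INT, comparisons over tiles (e.g. `bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);`) yield the boolean masks expected by predicated loads (`*?(mask)ptr`) and by the new masked atomic-add builtins below.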
@@ -277,12 +277,14 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
     ir::value* val = ret_;
     return set_ret(bld_->create_atomic_exch(ptr, val));
   }
-  if(name == "f32_atomic_add"){
+  if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
     VisitExpr(funcCall->Args()->at(0));
     ir::value* ptr = ret_;
    VisitExpr(funcCall->Args()->at(1));
     ir::value* val = ret_;
-    return set_ret(bld_->create_atomic_add(ptr, val));
+    VisitExpr(funcCall->Args()->at(2));
+    ir::value* msk = ret_;
+    return set_ret(bld_->create_atomic_add(ptr, val, msk));
   }
   if(name == "sqrtf"){
     VisitExpr(funcCall->Args()->at(0));

@@ -338,6 +340,7 @@ void Generator::VisitTempVar(TempVar* tempVar) {
 }

 // Statement
+// TODO: int x = x; crashes
 void Generator::VisitDeclaration(Declaration* decl) {
   auto obj = decl->obj_;
   // initialize to undef
@@ -650,7 +650,7 @@ Expr* Parser::ParseDerefOp(const Token* tok) {
   Expr* pred = nullptr;
   if(ts_.Try('?')){
     ts_.Expect('(');
-    pred = ParseCastExpr();
+    pred = ParseExpr();
     ts_.Expect(')');
   }
   Expr* addr = ParseCastExpr();
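This is one of the parser bugfixes from the commit message: the predicate of a masked dereference was previously parsed as a cast-expression, so only simple operands were accepted; parsing a full expression lets compound predicates such as `*?(checka && checkk)pa` through.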
@@ -239,6 +239,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
     throw std::runtime_error("using too much shared memory");
   barriers.run(module);
   isel.visit(module, *llvm);
+  // ir::print(module, std::cout);
   std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
   return res;
 }
@@ -351,6 +352,8 @@ std::string function::preheader() {
 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
 extern float f32_atomic_add(float*, float);
+extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
+extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
 extern int get_program_id(int);
 extern int get_num_programs(int);
 extern float sqrtf(float);
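These two declarations are the "hacky support for fixed-tile-size atomic-add" from the commit message: instead of one shape-polymorphic masked atomic-add, one builtin is declared per supported tile shape. A sketch of the reference semantics in PyTorch (the function name and the flat-offset convention are our assumptions for illustration, not part of the commit):

    import torch

    def atomic_add_64x64_ref(dst, offs, val, msk):
        # dst: flat float tensor; offs: int64 [64, 64] element offsets into dst
        # val, msk: [64, 64] values and boolean mask
        # index_put_ with accumulate=True mimics per-element atomic adds
        dst.index_put_((offs[msk],), val[msk], accumulate=True)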
109 python/examples/test.py (new file)
@@ -0,0 +1,109 @@
+import triton
+import numpy
+import torch
+import itertools
+
+torch.manual_seed(0)
+numpy.random.seed(0)
+
+def to_sparse(expr, data, layout, shape, block):
+    # shape of result
+    sparse = None
+    shape_ret = []
+    for i, d in enumerate(expr):
+        if d.isupper() and sparse is None:
+            sparse = i
+            shape_ret.append(int(layout.sum()))
+        if d.isupper():
+            shape_ret.append(block[d])
+        else:
+            shape_ret.append(shape[i])
+    # iterator
+    steps = [block[d] if d.isupper() else 1 for d in expr]
+    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
+    # create result
+    ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
+    blockid = 0
+    nzblockid = 0
+    for curr in itertools.product(*it):
+        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
+            blockid = 0
+            nzblockid = 0
+        data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
+        ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
+        ret_slice.insert(sparse, nzblockid)
+        if int(layout.view(-1)[blockid]):
+            ret[ret_slice] = data[data_slice]
+            nzblockid += 1
+        blockid += 1
+    return ret
+
+def to_dense(expr, data, layout, shape, block):
+    sparse = None
+    for i, d in enumerate(expr):
+        if d.isupper() and sparse is None:
+            sparse = i
+
+    ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
+    steps = [block[d] if d.isupper() else 1 for d in expr]
+    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
+    blockid = 0
+    nzblockid = 0
+    for curr in itertools.product(*it):
+        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
+            blockid = 0
+            nzblockid = 0
+        ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
+        data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
+        data_slice.insert(sparse, nzblockid)
+        if int(layout.view(-1)[blockid]):
+            ret[ret_slice] = data[data_slice]
+            nzblockid += 1
+        blockid += 1
+    return ret
+
+def test_expr(expr, shape, blocks):
+    # decompose expr
+    expr_a, expr_bc = expr.split(",")
+    expr_b, expr_c = expr_bc.split("->")
+    # check which argument is sparse
+    sparse_a = any(x.isupper() for x in expr_a)
+    sparse_b = any(x.isupper() for x in expr_b)
+    sparse_c = any(x.isupper() for x in expr_c)
+    # allocate data
+    shape_a = [shape[d.lower()] for d in expr_a]
+    shape_b = [shape[d.lower()] for d in expr_b]
+    shape_c = [shape[d.lower()] for d in expr_c]
+    ref_a = torch.rand(*shape_a, device='cuda')
+    ref_b = torch.rand(*shape_b, device='cuda')
+    ref_c = torch.zeros(*shape_c, device='cuda')
+    # layouts
+    layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
+    layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
+    layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
+    layout_a = torch.randint(0, 2, layout_a, device='cuda')
+    layout_b = torch.randint(0, 2, layout_b, device='cuda')
+    layout_c = torch.randint(0, 2, layout_c, device='cuda')
+    # triton computation
+    triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
+    triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
+    layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
+    triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
+    torch.cuda.synchronize()
+    # reference computation
+    ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
+    ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
+    ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
+    if sparse_c:
+        ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
+    torch.cuda.synchronize()
+    print((ref_c - triton_c).abs().max())
+
+
+# shape characteristics
+test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
+test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
+test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})
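For orientation, the API these tests drive: uppercase subscripts mark block-sparse dimensions, `layouts` maps each operand's subscript string to a 0/1 tensor with one flag per block, and `blocks` gives the block size of each uppercase letter. A minimal sketch following the first test case (values illustrative):

    # DSD product: A block-sparse over (H, M, K); B and C dense
    blocks = {'H': 1, 'M': 32, 'K': 32}
    layout_a = torch.randint(0, 2, (2, 8, 8), device='cuda')   # [h/1, m/32, k/32] block flags
    dense_a = torch.rand(2, 2, 256, 256, device='cuda')
    a = to_sparse('bHMK', dense_a, layout_a, [2, 2, 256, 256], blocks)  # packed nonzero blocks
    b = torch.rand(2, 2, 256, 256, device='cuda')
    c = triton.ops.einsum('bHMK,bhkn->bhmn', a, b, {'bHMK': layout_a}, blocks)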
@@ -56,8 +56,8 @@ class _dot(torch.autograd.Function):
     TYPE c[TM, TN] = acc;

     // epilogue
-    int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
-    int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
+    int rxm[TM] = ridx * TM + 0 ... TM;
+    int rxn[TN] = ridy * TN + 0 ... TN;
     int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
     TYPE* pc[TM, TN] = C + offc;
     bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);

@@ -95,7 +95,7 @@ class _dot(torch.autograd.Function):
        if dtype not in _dot.kernel:
            defines = {
                'TYPE' : dtype,
-               'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
+               'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                'TM' : [64, 128],
                'TN' : [64, 128],

@@ -107,14 +107,12 @@ class _dot(torch.autograd.Function):
        # allocate output
        M, K = a.shape
        K, N = b.shape
-       c = triton.empty([M,N], dtype=dtype)
+       c = torch.empty([M,N], dtype=dtype, device=a.device)
        # enqueue
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
                            triton.cdiv(N, opt.d('TN'))]
        time = kernel(a, b, c, 1., M, N, K,
-                     a.stride(0), b.stride(0), c.stride(0),
-                     grid=grid, bench=100)
-       print(2*M*N*K/(time*1e-6)*1e-9)
+                     a.stride(0), b.stride(0), c.stride(0), grid=grid)
        return c
@@ -126,8 +124,10 @@ M, N, K = 2048, 2048, 2048
 a = torch.rand((M, K)).cuda()
 b = torch.rand((K, N)).cuda()

+#a[:] = 1
+#b[:] = 1

-#zc = torch.matmul(a,b)
+zc = torch.matmul(a,b)
 zc_ = dot(a,b)

-#print(torch.allclose(zc, zc_))
+print(torch.allclose(zc, zc_))
@@ -110,7 +110,7 @@ setup(
    author_email='ptillet@g.harvard.edu',
    description='A language and compiler for custom Deep Learning operations',
    long_description='',
-   packages=['triton', 'triton/_C', 'triton/ops', 'triton/nn'],
+   packages=['triton', 'triton/_C', 'triton/ops'],
    install_requires=['numpy', 'torch', 'sympy'],
    package_data={'': data},
    ext_modules=[CMakeExtension('triton', 'triton/_C/')],
@@ -1,5 +1,5 @@
 from .kernel import *
-#import triton.ops
+import triton.ops
 #import triton.nn
@@ -82,6 +82,7 @@ class kernel:
        grid = kwargs['grid']
        libtriton.register_grid((self.op_id, device), grid)
        # launch
+       #print(self.tys)
        params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
        torch.cuda.synchronize()
        torch.ops.triton.launch_kernel(self.op_id, device, params)
@@ -5,8 +5,8 @@ from operator import mul
 from collections import OrderedDict
 from collections import namedtuple
-import re
+import string
 import triton
 # torch
 import torch
 # numpy -- ideally removed in a future release
 import numpy as np
@@ -22,33 +22,14 @@ class _einsum(torch.autograd.Function):
    #############################
    ## Triton-C code generation
    #############################
-   def print_cc(expr, axes_0, axes_1, axes_2):
-
-       class TritonCodePrinter(C89CodePrinter):
-
-           def __init__(self, axes_0, axes_1, axes_2):
-               super(TritonCodePrinter, self).__init__()
-               self.axes_0 = axes_0
-               self.axes_1 = axes_1
-               self.axes_2 = axes_2
-
-           def _print_Symbol(self, expr):
-               name = super(C89CodePrinter, self)._print_Symbol(expr)
-               if expr in self.axes_0:
-                   return f'r{name}[:, newaxis, newaxis]'
-               if expr in self.axes_1:
-                   return f'r{name}[newaxis, :, newaxis]'
-               if expr in self.axes_2:
-                   return f'r{name}[newaxis, newaxis, :]'
-               return name
-
-           def _print_Indexed(self, expr):
-               assert len(expr.indices) == 1
-               return "*(%s + %s)" % (self._print(expr.base.label),
-                                      self._print(expr.indices[0]))
-
-       return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)
+   def print_cc(expr, axes_0, axes_1, axes_2, prefix):
+       if expr in axes_0:
+           return f'{prefix}r{expr}[:, newaxis, newaxis]'
+       if expr in axes_1:
+           return f'{prefix}r{expr}[newaxis, :, newaxis]'
+       if expr in axes_2:
+           return f'{prefix}r{expr}[newaxis, newaxis, :]'
+       return expr

    def unpack_cc(tile, axes, prefix, remat):
        ret = ''

@@ -59,7 +40,7 @@ class _einsum(torch.autograd.Function):
            currs = ''.join(axes[: len(axes) - i])
            nexts = ''.join(axes[: len(axes) - (i + 1)])
            ty = '' if remat else 'int '
-           sz = '' if remat else f'[{tile}]'
+           sz = '' if remat or tile is None else f'[{tile}]'
            ret += f'    {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
            ret += f'    {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
        return ret
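unpack_cc emits a chain of div/mod statements that splits one flattened index range back into per-axis coordinates (e.g. `rbm = rbmn / dim_n; rn = rbmn % dim_n;`). A standalone sketch of the same arithmetic in plain Python, not the generated Triton-C:

    def unpack(flat, dims):
        # dims ordered outer-to-inner, e.g. {'m': 2, 'n': 3}
        out = {}
        for d, size in reversed(list(dims.items())):
            out[d] = flat % size
            flat = flat // size
        return out

    unpack(5, {'m': 2, 'n': 3})   # -> {'n': 2, 'm': 1}, i.e. flat index 5 = (m=1, n=2)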
@@ -69,18 +50,26 @@ class _einsum(torch.autograd.Function):
        ret = dict(zip(expr, ret))
        return ret

-   def make_kernel(name, dtype, mask,
+   def make_kernel(name, dtype,
                    expr_a, expr_b, expr_c,
+                   sparse_a, sparse_b, sparse_c,
                    axes_m, axes_n, axes_k, axes_b,
                    multipleof_a, multipleof_b, multipleof_c,
                    stride_a_last, stride_b_last, stride_c_last,
                    lut_mode_a, lut_mode_b,
                    delta_a, delta_b,
-                   subscripted, varnames):
+                   blocks):

        use_lut_a = True
        use_lut_b = True

+       outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
+       outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
+       outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
+       outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
+       outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]
+
        src = ""

        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
@@ -105,41 +94,91 @@ __global__ void {name}(
        for d in dim:
            src += f", int dim_{d}"
        src += "\n    "
-       for dim, name, mult in zip([expr_a, expr_b, expr_c],
-                                  ['a', 'b', 'c'],
-                                  [multipleof_a, multipleof_b, multipleof_c]):
+       for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
+                                          ['a', 'b', 'c'],
+                                          [multipleof_a, multipleof_b, multipleof_c],
+                                          [sparse_a, sparse_b, sparse_c]):
            for d in range(len(dim) - 1):
-               attr = f'__multipleof({mult})'
-               src += f", int stride_{name}_{d} {attr}"
+               if sparse and dim[d] == sparse[0]:
+                   src += f', int stride_{name}_block __multipleof({mult})'
+               src += f", int stride_{name}_{d} __multipleof({mult})"
            src += "\n    "
        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
            src += f", int rem_delta_a __multipleof({multipleof_a})"
-       elif lut_mode_a == _einsum.LUT_MODE.DRAM:
+       elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
            src += ", int* AD __noalias __readonly __aligned(16)"
        src += "\n    "
        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_b_inner __multipleof({multipleof_b})"
            src += f", int rem_delta_b __multipleof({multipleof_b})"
-       elif lut_mode_b == _einsum.LUT_MODE.DRAM:
+       elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
            src += ", int* BD"
        src += "\n"
-       for ptr in subscripted:
-           src += f", int* {ptr}"
-       for name in varnames:
-           src += f", int {name}"
-       src += "\n    "
+       if sparse_c:
+           src += ", int* CD"
+       if sparse_a or sparse_b:
+           src += ", int width"
        src += """) {

+    // program identifiers
+    int pid_0 = get_program_id(0);
+    int pid_1 = get_program_id(1);
+
+"""
+       if sparse_a:
+           src += f"""
+    int off_n = pid_0 / width;
+    int off_header = pid_0 % width;
+    int* header = AD + off_header * {2 + len(outer_sparse_a)};
+    int* pdelta = AD + *(header + 0);
+    matmul_k = *(header + 1);"""
+           for i, d in enumerate(outer_sparse_a):
+               src += f"""
+    int off_{d} = *(header + {2 + i});"""
+           src += f"""
+    int inca = *(pdelta + 0);
+    int incb = *(pdelta + 1);
+    int off_{''.join(map(str, outer_dense_a))} = pid_1;
+"""
+           _einsum.unpack_cc(None, outer_dense_a, "off_", False)
+       elif sparse_b:
+           src += f"""
+    int off_m = pid_0 / width;
+    int off_header = pid_0 % width;
+    int* header = BD + off_header * {2 + len(outer_sparse_b)};
+    int* pdelta = BD + *(header + 0);
+    matmul_k = *(header + 1);"""
+           for i, d in enumerate(outer_sparse_b):
+               src += f"""
+    int off_{d} = *(header + {2 + i});"""
+           src += f"""
+    int incb = *(pdelta + 0);
+    int inca = *(pdelta + 1);
+    int off_{''.join(map(str, outer_dense_b))} = pid_1;
+"""
+           _einsum.unpack_cc(None, outer_dense_b, "off_", False)
+       elif sparse_c:
+           src += f"""
+    // load LUT header
+    int *header = CD + pid_0 * {len(sparse_c)};"""
+           for i, d in enumerate(sparse_c):
+               src += f"""
+    int off_{d} = *(header + {i});"""
+           src += f"""
+    int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
+       else:
+           src += """
    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
-    int pid_mn = get_program_id(0) / div_m;
-    int pid_n = pid_mn % grid_n;
-    int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
-
-    // get batch program id
-    int pid_b = get_program_id(1);
+    int off_mn = pid_0 / div_m;
+    int off_n = off_mn % grid_n;
+    int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
+    int off_b = get_program_id(1);"""

        src += """
    #if TZ == 1
    int off_k = 0;
    #else
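In the dense fallback path, pid_m/pid_n are renamed to off_m/off_n but the grouped decoding of the linear program id is unchanged. A sketch of that decoding in plain Python:

    def decode_pid(pid_0, grid_n, div_m):
        # mirrors: off_mn = pid_0 / div_m; off_n = off_mn % grid_n;
        #          off_m = (off_mn / grid_n)*div_m + pid_0 % div_m
        off_mn = pid_0 // div_m
        off_n = off_mn % grid_n
        off_m = (off_mn // grid_n) * div_m + pid_0 % div_m
        return off_m, off_n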
@@ -155,55 +194,71 @@ __global__ void {name}(

    // create ranges
"""
-       rk = 'r{}'.format(''.join(map(str,axes_k)))
-       for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
-                                  ['TM', 'TN', 'TB', 'TK'],
-                                  ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
-           currs = ''.join(map(str,axes))
-           if axes:
-               src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
-               src += _einsum.unpack_cc(tile, axes, 'r', False)
-
-       src += """
+       sparse = sparse_a + sparse_b + sparse_c
+       for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
+                                            ['TM', 'TN', 'TB', 'TK'],
+                                            ['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
+                                            [['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
+           if not axes:
+               continue
+           currs = ''.join(map(str,axes))
+           has_sparse_component = set(axes) & set(sparse)
+           if has_sparse_component:
+               src += f"    int r{currs}[{tile}] = 0 ... {tile};\n"
+               src += _einsum.unpack_cc(tile, axes, f'r', False)
+           else:
+               src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
+               src += _einsum.unpack_cc(tile, axes, f'r', False)
+           for pfx in prefixes:
+               for d in axes:
+                   is_dense_dim = d not in sparse
+                   is_dense_storage = (pfx == 'a' and not sparse_a) or\
+                                      (pfx == 'b' and not sparse_b) or\
+                                      (pfx == 'c' and not sparse_c)
+                   if not is_dense_dim and is_dense_storage:
+                       src += f"    int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
+                   elif is_dense_dim and has_sparse_component:
+                       src += f"    int {pfx}r{d}[{tile}] = off_{d};\n"
+                   else:
+                       src += f"    int {pfx}r{d}[{tile}] = r{d};\n"

        src += f"""
    // initialize pointers to A
-    int offa[TM, TK, TB] = """
+    int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
        for i, sym in enumerate(expr_a):
-           ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
+           ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
            stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
-           if i > 0:
-               src += ' + '
-           src += f"({ccode}) * {stride}\n "
+           src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""

-       if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
+       if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to A look-up table
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK] = {cast}AD + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""

-       src += """
+       src += f"""

    // initialize pointers to B
-    int offb[TK, TN, TB] = """
+    int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
        for i, sym in enumerate(expr_b):
-           ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
+           ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
            stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
-           if i > 0:
-               src += ' + '
-           src += f"({ccode}) * {stride}\n "
+           src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""

-       if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
+       if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
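print_cc now just tags an axis name with a broadcasting pattern and a per-tensor prefix, which the offset builders above multiply by the matching stride. An illustrative call (assuming 'm' is in axes_m and 'k' in axes_k):

    _einsum.print_cc('m', axes_m, axes_k, axes_b, 'a')   # -> 'arm[:, newaxis, newaxis]'
    _einsum.print_cc('k', axes_m, axes_k, axes_b, 'a')   # -> 'ark[newaxis, :, newaxis]'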
@@ -211,66 +266,99 @@ __global__ void {name}(
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK] = BD + offbdelta;"""

+       rk = 'r{}'.format(''.join(map(str,axes_k)))
        src += f"""

    // prefetch
    int prefetch_k = select(rem_k > 0, rem_k, TK);
-    bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
-    bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
-    bool checkk[TK] = {rk} < prefetch_k;
-    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
-    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
+    bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
+    bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
+    bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
+    bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
+    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""

-       if lut_mode_a == _einsum.LUT_MODE.SCALAR:
+       if sparse_a:
+           src += f"""
+    // update pointers to look-up tables
+    pdelta += 2;
+    int incda = *(pdelta + 0);
+    int incdb = *(pdelta + 1);
+    pa += incda;
+    pb += incdb;"""
+       if sparse_b:
+           src += f"""
+    // update pointers to look-up tables
+    pdelta += 2;
+    int incdb = *(pdelta + 0);
+    int incda = *(pdelta + 1);
+    pa += incda;
+    pb += incdb;"""

+       if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += """
    pa += rem_delta_a;"""
-       else:
+       elif not sparse_a and not sparse_b:
            src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""

-       if lut_mode_b == _einsum.LUT_MODE.SCALAR:
+       if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
    pb += rem_delta_b;"""
-       else:
+       elif not sparse_a and not sparse_b:
            src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""

        src += f"""

    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;
    #ifdef MASK
        uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
        acc = bitcast<float[TM, TN, TB]>(bits & MASK);
    #endif

        // load inputs
        checkk = k > TK;
-        checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
-        checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
+        checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
+        checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
        a = *?(checka)pa;
-        b = *?(checkb)pb;"""
+        b = *?(checkb)pb;

-       if lut_mode_a == _einsum.LUT_MODE.SCALAR:
-           src += """
-        pa += stride_a_inner;"""
-       else:
-           src += """
-        pa += incda;
-        padelta += TK;
-        incda = (*padelta)[newaxis, :, newaxis];"""
-
-       if lut_mode_b == _einsum.LUT_MODE.SCALAR:
-           src += """
-        pb += stride_b_inner;"""
-       else:
-           src += """
-        pb += (*pbdelta)[:, newaxis, newaxis];
-        pbdelta += TK;"""
+        // update pointers"""
+       if sparse_a:
+           src += """
+        pdelta += 2;
+        incda = *(pdelta + 0);
+        incdb = *(pdelta + 1);
+        pa += incda;
+        pb += incdb;
+        """
+       elif sparse_b:
+           src += """
+        pdelta += 2;
+        incdb = *(pdelta + 0);
+        incda = *(pdelta + 1);
+        pa += incda;
+        pb += incdb;
+        """
+       else:
+           if lut_mode_a == _einsum.LUT_MODE.SCALAR:
+               src += """
+        pa += stride_a_inner;"""
+           else:
+               src += """
+        pa += incda;
+        padelta += TK;
+        incda = (*padelta)[newaxis, :, newaxis];"""
+
+           if lut_mode_b == _einsum.LUT_MODE.SCALAR:
+               src += """
+        pb += stride_b_inner;"""
+           else:
+               src += """
+        pb += (*pbdelta)[:, newaxis, newaxis];
+        pbdelta += TK;"""
@@ -278,38 +366,22 @@ __global__ void {name}(
    }}
    TYPE c[TM, TN, TB] = acc;

-    // re-materialize ranges
-    pid_mn = get_program_id(0) / div_m;
-    pid_n = pid_mn % grid_n;
-    pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
-"""
-       for axes, tile, off in zip([axes_m, axes_n, axes_b],
-                                  ['TM', 'TN', 'TB'],
-                                  ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
-           currs = ''.join(map(str,axes))
-           if axes:
-               src += f"    r{currs} = {off} + 0 ... {tile};\n"
-               src += _einsum.unpack_cc(tile, axes, 'r', True)

-       src += """
+       src += f"""
    // initialize pointers to C
-    int offc[TM, TN, TB] = """
+    int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
        for i, sym in enumerate(expr_c):
            stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
-           ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
-           if i > 0:
-               src += ' + '
-           src += f"({ccode}) * {stride}\n "
+           ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
+           src += f"\n + ({ccode}) * {stride}"
        src += ';'

        src += """
    TYPE *pc[TM, TN, TB] = C + offc;

    // bounds-checking
-    checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
-    checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
-    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
-                              checkn[newaxis, :, newaxis];
+    bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
+    bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
+    bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
+                              checkcn[newaxis, :, newaxis];

    // write back
    #if TZ == 1

@@ -330,11 +402,11 @@ __global__ void {name}(
 }
 """
        # compilation options
-       TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
+       TM, TN, TB, TZ = [32], [32], 1, [1]
        TK = 16 if dtype==torch.float16 else 8
        defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
-       if mask is not None:
-           defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
+       for d, B in blocks.items():
+           defines[f'BLOCK{d}'] = B
        # create kernel
        ret = triton.kernel(src, defines=defines)
        # set constant
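Block sizes now reach the kernel as preprocessor defines, which the range-creation code references as BLOCKM, BLOCKK, etc. A small illustration (values from the first test case; note the autotuning space was also collapsed here to a single 32x32 tile):

    # blocks = {'H': 1, 'M': 32, 'K': 32} yields
    # defines -> {'TM': [32], 'TN': [32], 'TB': 1, 'TK': ..., 'TZ': [1], 'TYPE': ...,
    #             'BLOCKH': 1, 'BLOCKM': 32, 'BLOCKK': 32}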
@@ -376,27 +448,81 @@ __global__ void {name}(
                k = k // dims[d]
            return ret

-   def make_delta(axes, step, stride, dims, symbols, arrays):
+   def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
+       # depth of reductions
+       depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
+       # outer dimension indices
+       outer = torch.nonzero(depth, as_tuple=False)
+       outer = [outer[:,i] for i in range(outer.shape[1])]
+       # find offset of outer dimensions
+       depth = depth.view(-1)
+       offsets = torch.zeros_like(depth)
+       offsets[1:] = torch.cumsum(depth[:-1], 0)
+       # compute delta for b
+       # TODO: support multiple sparse red indices
+       col = next((i for i, d in enumerate(sparse) if d in axes), None)
+       block = blocks[sparse[-1].upper()]
+       div = block // step
+       delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
+       delta_b *= block
+       delta_b = [delta_b + step*i for i in range(div)]
+       delta_b = torch.stack(delta_b, dim=1)
+       delta_b = delta_b.view(-1)
+       # compute delta for a
+       bstride = 1
+       for d in sparse[::-1]:
+           if d in axes:
+               break
+           bstride *= blocks[d.upper()]
+       order = [d for d in sparse if d not in axes] +\
+               [d for d in sparse if d in axes]
+       idx = [sparse.index(d) for d in order]
+       layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
+       layout = layout.permute(*idx)
+       delta_a = layout[layout > 0] - 1
+       delta_a *= np.prod(list(blocks.values()))
+       saved = delta_a[offsets]
+       delta_a[1:] = delta_a[1:] - delta_a[:-1]
+       delta_a = delta_a.view(-1, 1).repeat(1, div)
+       delta_a[:, 1:] = step*bstride
+       delta_a[:, 0] -= (div - 1)*step*bstride
+       delta_a[offsets, 0] = saved
+       delta_a = delta_a.view(-1)
+       delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
+       # form look-up table
+       depth *= blocks[symbols[-1].upper()]
+       offsets *= div
+       header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
+       nouter = 2 + len(outer)
+       header[::nouter] = header[::nouter]*2 + header.shape[0]
+       lut = torch.cat((header, delta)).int().cpu().numpy()
+       return lut, nouter, _einsum.LUT_MODE.DRAM
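The DRAM look-up table built here is, per nonzero outer index, a fixed-size header followed by a stream of pointer-increment pairs. An illustrative layout (field names are ours, inferred from the generated kernel code above):

    # lut    = [ header_0, header_1, ..., delta stream ]
    # header = [ offset_into_delta (pre-scaled: *2 + len(header)),
    #            matmul_k for this row, outer-dimension indices... ]
    # delta  = [ inc_a_0, inc_b_0, inc_a_1, inc_b_1, ... ]

The kernel consumes it exactly as generated earlier: header = AD + off_header * nouter; pdelta = AD + *(header + 0); matmul_k = *(header + 1).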
+   def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
        # symbolic pointer increments
        symbols = [sp.symbols(x) for x in symbols]
        delta = _einsum.symbolic_delta(symbols, axes)
        args = [f'stride{d}' for d in range(len(stride))]
        args += [f'{sk}' for sk in axes]
        args += [f'next{sk}' for sk in axes]
-       args += [f'{sk}' for sk, _ in arrays]
        fn = sp.lambdify(args, delta, 'numpy')
-       # inner axes values
-       inner = [dims[d] for d in axes]
-       inner = np.prod(inner)
-       rem = inner % step
-       rem = rem if rem > 0 else step
-       # k = [0, 1, ..., step,
-       #      rem, rem + 1, ... rem + inner]
-       k = np.concatenate((np.arange(step),
-                           np.arange(rem, inner))).astype(np.int32)
-       # nextk = [rem, 1 + rem, ..., step + rem,
-       #          rem + step, rem + 1 + step, ..., inner + step]
-       nextk = np.concatenate((k[:step] + rem,
-                               k[step:] + step))
+       if lut is None:
+           # inner axes values
+           inner = [dims[d] for d in axes]
+           inner = np.prod(inner)
+           rem = inner % step
+           rem = rem if rem > 0 else step
+           # k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
+           # nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
+           k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
+           nextk = np.concatenate((k[:step] + rem, k[step:] + step))
+       else:
+           idx = (lut[:lut[0]:nouter] - lut[0])//2
+           k = lut[lut[0]+1::2]
+           k = np.insert(k, idx, 0)
+           nextk = k[1:]
+           k = k[:-1]
        # offsets
        off = _einsum.unpack_offset(k, axes, dims)
        nextoff = _einsum.unpack_offset(nextk, axes, dims)
@@ -404,71 +530,20 @@ __global__ void {name}(
        args = [s for s in stride]
        args += [off[sk] for sk in axes]
        args += [nextoff[sk] for sk in axes]
-       args += [x for _, x in arrays]
        delta = fn(*args)
+       delta = np.maximum(delta, 0)
+       if lut is not None:
+           idx = idx[1:] + np.arange(idx.shape[0] - 1)
+           delta = np.delete(delta, idx)
+           lut[lut[0]+1::2] = delta
+           return None, None
        return delta, _einsum.lut_mode(delta[step:-step])

-   ############################
-   ## Einsum parsing
-   ############################
-
-   def uniq(seq):
-       seen = set()
-       seen_add = seen.add
-       return [x for x in seq if not (x in seen or seen_add(x))]
-
-   def parse_axes(expr_a, expr_b, expr_c, subscripted):
-       is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted
-       sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
-       sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
-       sym_c = [x for s in expr_c for x in s.free_symbols]
-       batch = [d for d in sym_a if d in sym_b and d in sym_c]
-       outer = [d for d in sym_a if d not in sym_b and d in sym_c]
-       inner = [d for d in sym_a if d in sym_b and d not in sym_c]
-       variables = [d for d in sym_a if d not in sym_b and d not in sym_c]
-       return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner), variables
-
-   def replace_subscript(expr, arrays):
-       # replace array indexing by Indexed()
-       indexed = re.findall('([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
-       for x in indexed:
-           arrays.append(x[0])
-           expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
-       return expr
-
-   def parse_expr(expr, arrays):
-       # extract symbols
-       sym = []
-       i = 0
-       while i < len(expr):
-           d = expr[i]
-           if d == '(':
-               size = expr[i:].find(')')
-               d = expr[i : i + size + 1]
-               d = _einsum.replace_subscript(d, arrays)
-               sym.append(parse_expr(d))
-               i += size + 1
-           else:
-               sym.append(parse_expr(d))
-               i += 1
-       return sym
-
    ############################
    ## Preprocessing
    ############################

    @staticmethod
    def pad(tensor, pad):
        pad = pad + [0] * (2*len(tensor.shape) - len(pad))
        begin = [ x if x > 0 else None for x in pad[-1::-2]]
        end = [-x if x > 0 else None for x in pad[-2::-2]]
        slices = [slice(b, e) for b, e in zip(begin, end)]
        tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0)
        tensor = tensor[slices]
        return tensor

+   def make_sdd_lut(layout_c, sparse_c, blocks):
+       nnz = torch.nonzero(layout_c, as_tuple=False)
+       lut = nnz.reshape(-1).int().cuda()
+       return lut
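make_sdd_lut is simply the flattened nonzero coordinates of the output layout; pid_0 indexes these records in the sparse_c path of the generated kernel. A worked example:

    layout_c = torch.tensor([[1, 0],
                             [0, 1]])
    # torch.nonzero -> [[0, 0], [1, 1]]; flattened lut = [0, 0, 1, 1],
    # i.e. one (row, col) block coordinate pair per nonzero output block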
    ############################
    ## Compilation
    ############################

@@ -479,68 +554,78 @@ __global__ void {name}(
        locks = None
        kernel_cache = dict()

        @staticmethod
        def _tile(M, N, B, TMs, TNs, TBs, TZs, TK):
            smp = 15
            # occupancy estimation
            grid = lambda TM, TN, TB, TZ:   \
                        triton.cdiv(M, TM)* \
                        triton.cdiv(N, TN)* \
                        triton.cdiv(B, TB)* \
                        TZ
            occupancy = lambda TM, TN, TB, TZ: \
                           min(grid(TM, TN, TB, TZ), 4*smp)
            # arithmetic intensity estimation
            intensity = lambda TM, TN: \
                           TM * TN * TK / (TM*TK + TK*TN)
            # occupancy/intensity for all configurations
            estimates = {(TM, TN, TB, TZ): (occupancy(TM, TN, TB, TZ), intensity(TM, TN)) \
                    for TM in TMs \
                    for TN in TNs \
                    for TB in TBs \
                    for TZ in TZs }
            # returns configuration that maximizes occupancy subject to maximizing intensity
            estimates = sorted(estimates.items(),
                               key=lambda item: item[1],
                               reverse=True)
            return estimates[0][0]
-       def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c, varnames):
+       def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
            # parse symbols
            expr_a, expr_bc = einsum.split(",")
            expr_b, expr_c = expr_bc.split("->")
-           subscripted = []
-           sym_a = _einsum.parse_expr(expr_a, subscripted)
-           sym_b = _einsum.parse_expr(expr_b, subscripted)
-           sym_c = _einsum.parse_expr(expr_c, subscripted)
+           sym_a = expr_a.lower()
+           sym_b = expr_b.lower()
+           sym_c = expr_c.lower()
+           sparse_a = [x.lower() for x in expr_a if x.isupper()]
+           sparse_b = [x.lower() for x in expr_b if x.isupper()]
+           sparse_c = [x.lower() for x in expr_c if x.isupper()]
+           layout_a = layouts.get(expr_a)
+           layout_b = layouts.get(expr_b)
+           layout_c = layouts.get(expr_c)
            # parse axes
-           axes_b, axes_m, axes_k, var = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
-           _, axes_n, _, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
+           axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
+           axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
+           axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
+           axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
            axes = axes_b + axes_m + axes_n + axes_k
-           # unresolved symbols
-           unresolved = [x for x in map(str, var) if x not in varnames]
-           if unresolved:
-               raise ValueError(f'unresolved symbols: {unresolved}')
+           # check block sizes
+           for d in sparse_a + sparse_b + sparse_c:
+               if d.upper() not in blocks:
+                   raise ValueError(f'unspecified block size for dimension: {d.upper()}')
+           # check layout is present
+           if sparse_a and layout_a is None:
+               raise ValueError('A is sparse but no layout was provided')
+           if sparse_b and layout_b is None:
+               raise ValueError('B is sparse but no layout was provided')
+           if sparse_c and layout_c is None:
+               raise ValueError('C is sparse but no layout was provided')
            # check dimensions
-           dims_a = dict(zip(sym_a, shape_a))
-           dims_b = dict(zip(sym_b, shape_b))
-           dims_c = dict(zip(sym_c, shape_c))
-           for axes in [axes_b, axes_k]:
-               for d in axes:
-                   dim_a = dims_a[d] if d in sym_a else None
-                   dim_b = dims_b[d] if d in sym_b else None
-                   if dim_a and dim_b and dim_a != dim_b:
-                       raise ValueError(f'incompatible dimension {d}'
-                                        f' (a: {dim_a}; b: {dim_b})')
+           dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
+           dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
+           dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
+           dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
+           # TODO: could be cleaner
+           read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
+           for d in axes_b + axes_m + axes_n + axes_k:
+               dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
+               dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
+               if d in axes_b and dim_a and dim_b and dim_a != dim_b:
+                   raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
+               if d in axes_k and dim_a and dim_b and dim_a != dim_b:
+                   raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
            dims = dict()
            dims.update(dims_a)
            dims.update(dims_b)
-           dims.update(dims_c)
+           for i, d in enumerate(sparse_a):
+               dims[d] = layout_a.shape[i] * blocks[d.upper()]
+           for i, d in enumerate(sparse_b):
+               dims[d] = layout_b.shape[i] * blocks[d.upper()]
+           # allocate output
+           shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
+           if sparse_c:
+               shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
+           stride_c = [None] * len(shape_c)
+           stride_c[-1] = 1
+           for i in reversed(range(len(shape_c) - 1)):
+               stride_c[i] = stride_c[i+1] * shape_c[i+1]
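The output is allocated dense and row-major over shape_c, so its strides are a running product from the right. A quick check of the loop above:

    shape_c = [4, 3, 2]
    # stride_c -> [6, 2, 1]   (rightmost stride 1, then 2*1, then 3*2)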
            # look-up tables
            TK = 16 if dtype == torch.float16 else 8
-           arrays = [(x, arrays[x]) for x in subscripted]
-           delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
-           delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
+           if sparse_a and not sparse_b:
+               delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
+               delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
+           if sparse_b and not sparse_a:
+               delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
+               delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
+           if not sparse_a and not sparse_b:
+               delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
+               delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
+           if sparse_c:
+               delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
            # hash for recompilation
            stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
            stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
@@ -548,7 +633,7 @@ __global__ void {name}(
            stride_a_last = stride_a[-1]
            stride_b_last = stride_b[-1]
            stride_c_last = stride_c[-1]
-           name = f'{dtype}_{mask}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
+           name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
                   f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
                   f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
            # recompile if necessary
@@ -556,14 +641,15 @@ __global__ void {name}(
            if name not in cache:
                cachesize = len(cache)
                cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
-                                                 dtype, mask,
+                                                 dtype,
                                                  sym_a, sym_b, sym_c,
+                                                 sparse_a, sparse_b, sparse_c,
                                                  axes_m, axes_n, axes_k, axes_b,
                                                  stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                  stride_a_last, stride_b_last, stride_c_last,
                                                  lut_mode_a, lut_mode_b,
                                                  delta_a, delta_b,
-                                                 subscripted, varnames)
+                                                 blocks)
            self.kernel = cache[name]
            # Initialize locks
            if _einsum.instance.locks is None:
@@ -576,42 +662,55 @@ __global__ void {name}(
            M = reduce(mul, dim_m, 1)
            N = reduce(mul, dim_n, 1)
            K = reduce(mul, dim_k, 1)
-           B = reduce(mul, dim_b, 1)
+           B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
            stride_a = list(stride_a[:-1])
            stride_b = list(stride_b[:-1])
            stride_c = list(stride_c[:-1])
-           arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
            alpha = 1.
            div_m = 1
-           self.args = [None, None, None,
-                        _einsum.instance.locks,
-                        alpha, M, N, K, div_m] +\
-                        dim_m + dim_n + dim_k + dim_b +\
-                        stride_a + stride_b + stride_c
+           self.args = [None, None, None]
+           self.args += [_einsum.instance.locks]
+           self.args += [alpha, M, N, K, div_m]
+           self.args += dim_m
+           self.args += dim_n
+           self.args += dim_k
+           self.args += dim_b
+           self.args += stride_a
+           self.args += stride_b
+           self.args += stride_c
            # LUT for A
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_a[TK], delta_a[0]]
-           if lut_mode_a == _einsum.LUT_MODE.DRAM:
+           elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_a).cuda()]
            # LUT for B
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_b[TK], delta_b[0]]
-           if lut_mode_b == _einsum.LUT_MODE.DRAM:
+           elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_b).cuda()]
-           # Einsum dependents
-           self.args += arrays
-           self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
-                                    triton.cdiv(N, opt.d('TN')),
-                                    triton.cdiv(B, opt.d('TB')),
-                                    opt.d('TZ')]
+           # LUT for C
+           if sparse_c:
+               self.args += [delta_c]
+           if sparse_a or sparse_b:
+               width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
+               self.args += [width]
+           # Grid
+           if sparse_a:
+               self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
+           elif sparse_b:
+               self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
+           elif sparse_c:
+               width = int(layout_c.sum())
+               self.grid = lambda opt: [width, B, opt.d('TZ')]
+           else:
+               self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
+                                        triton.cdiv(N, opt.d('TN')),
+                                        triton.cdiv(B, opt.d('TB')),
+                                        opt.d('TZ')]
            # position of dynamic arguments
            self.pos_a = 0
            self.pos_b = 1
            self.pos_c = 2
-           # user-provided variables
-           self.pos_vars = len(self.args)
-           self.varnames = varnames
-           self.args += [None] * len(varnames)
            # save information on the operation
            self.expr_a = expr_a
            self.expr_b = expr_b
@@ -620,16 +719,17 @@ __global__ void {name}(
            self.matmul_M = M
            self.matmul_N = N
            self.matmul_K = K
            self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
+           # output shape
+           self.shape_c = shape_c


-       def run(self, a, b, c, values, bench):
+       def run(self, a, b):
+           c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
            self.args[self.pos_a] = a
            self.args[self.pos_b] = b
            self.args[self.pos_c] = c
-           for i, name in enumerate(self.varnames):
-               self.args[self.pos_vars + i] = values[name]
-           return self.kernel(*self.args, grid=self.grid, bench=bench)
+           self.kernel(*self.args, grid=self.grid)
+           return c
@@ -639,25 +739,23 @@ __global__ void {name}(
    ############################

    instance_cache = dict()
-   registry = triton.utils.id_dict()
+   registry = dict()
    @staticmethod
-   def forward(ctx, expr, a, b, output, mask, arrays, bench, values):
+   def forward(ctx, expr, a, b, layouts, blocks):
        # compile einsum instance
        cache = _einsum.instance_cache
        key = (expr, a.dtype,
-              a.stride(), b.stride(), output.stride(),
-              a.shape, b.shape, output.shape, mask)
+              a.stride(), b.stride(),
+              a.shape , b.shape)
        if key not in cache:
            cache[key] = _einsum.instance(expr, a.dtype,
-                                         a.stride(), b.stride(), output.stride(),
-                                         a.shape, b.shape, arrays,
-                                         mask, output.shape, values.keys())
+                                         a.stride(), b.stride(),
+                                         a.shape , b.shape ,
+                                         layouts, blocks)
        instance = cache[key]
-       # run and mark as dirty output modified in-place
-       perf = instance.run(a, b, output, values, bench)
-       ctx.mark_dirty(output)
+       # run and mark as dirty c modified in-place
+       c = instance.run(a, b)
        # save information in context
        ctx.is_extended = instance.is_extended
        ctx.expr_a = instance.expr_a
        ctx.expr_b = instance.expr_b
        ctx.expr_c = instance.expr_c

@@ -665,10 +763,8 @@ __global__ void {name}(
        ctx.matmul_M = instance.matmul_M
        ctx.matmul_N = instance.matmul_N
        ctx.matmul_K = instance.matmul_K
-       ctx.forward_ms = perf
        ctx.save_for_backward(a, b)
-       _einsum.registry[output] = ctx
-       return output
+       return c


    ############################
@@ -677,9 +773,6 @@ __global__ void {name}(

    @staticmethod
    def backward(ctx, dy):
-       if ctx.is_extended:
-           raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
-                                     ' print write your own autograd function')
        a, b = ctx.saved_tensors
        expr_a = ctx.expr_a
        expr_b = ctx.expr_b

@@ -694,10 +787,8 @@ __global__ void {name}(
        if ctx.needs_input_grad[2]:
            db = torch.empty_like(b)
            einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
-       return None, da, db, None, None, None, None, None
+       return None, da, db, None, None, None, None, None, None, None


-def einsum(expr, a, b, output,
-          mask=None, arrays=dict(),
-          bench=False, values=dict()):
-   return _einsum.apply(expr, a, b, output, mask, arrays, bench, values)
+def einsum(expr, a, b, layouts = None, blocks = dict()):
+   return _einsum.apply(expr, a, b, layouts, blocks)