more work
This commit is contained in:
@@ -187,6 +187,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
|
||||
|
||||
void generator::visit_value(ir::value* v) {
|
||||
std::cout << "visiting " << typeid(*v).name() << std::endl;
|
||||
if(!seen_.insert(v).second)
|
||||
return;
|
||||
// create machine tile
|
||||
@@ -559,8 +560,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
|
||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||
bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
|
||||
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
|
||||
bool is_a_row = is_a_trans ^ (ord_a[0] == 1);
|
||||
bool is_b_row = is_b_trans ^ (ord_b[0] == 1);
|
||||
|
||||
|
||||
Value *offset_a_i = hmma->offset_a_i_;
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
|
||||
|
@@ -221,6 +221,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::cts cts;
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
std::cout << "begin" << std::endl;
|
||||
disassociate.run(module);
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
@@ -244,9 +245,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
std::cout << "isel" << std::endl;
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
isel.visit(module, *llvm);
|
||||
std::cout << "done" << std::endl;
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
// done
|
||||
|
@@ -1,2 +1,2 @@
|
||||
from .dot import dot
|
||||
from .einsum import einsum
|
||||
from .dot import _dot, dot
|
||||
from .einsum import _einsum, einsum
|
||||
|
@@ -3,7 +3,7 @@ import triton
|
||||
class _einsum(triton.function):
|
||||
|
||||
src = """
|
||||
void einsum(TYPE * A, TYPE * B, TYPE * C,
|
||||
void einsum_(TYPE * A, TYPE * B, TYPE * C,
|
||||
int dim_M, int dim_N, int dim_K,
|
||||
int std_A0, int std_B0, int std_C0,
|
||||
int std_A1, int std_B1, int std_C1) {
|
||||
@@ -30,7 +30,7 @@ class _einsum(triton.function):
|
||||
for(int k = dim_K; k > 0; k -= TK) {
|
||||
TYPE a[SHAPE_A] = *pa;
|
||||
TYPE b[SHAPE_B] = *pb;
|
||||
c += a @ b;
|
||||
c += USE_A @ USE_B;
|
||||
pa += TK;
|
||||
pb += TK;
|
||||
}
|
||||
@@ -157,6 +157,7 @@ class _einsum(triton.function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, subscripts, a, b):
|
||||
ctx.save_for_backward(a, b)
|
||||
if type(subscripts) is str:
|
||||
einsum_a, einsum_bc = subscripts.split(",")
|
||||
einsum_b, einsum_c = einsum_bc.split("->")
|
||||
@@ -165,8 +166,41 @@ class _einsum(triton.function):
|
||||
|
||||
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
|
||||
einsum_a, einsum_b, einsum_c,
|
||||
a.shape, b.shape
|
||||
a.shape.as_list(), b.shape.as_list()
|
||||
)
|
||||
ctx.trans_a = ta
|
||||
ctx.trans_b = tb
|
||||
ctx.einsum_a = einsum_a
|
||||
ctx.einsum_b = einsum_b
|
||||
ctx.einsum_c = einsum_c
|
||||
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
a, b = ctx.saved_tensors
|
||||
trans_a = ctx.trans_a
|
||||
trans_b = ctx.trans_b
|
||||
einsum_a = ctx.einsum_a
|
||||
einsum_b = ctx.einsum_b
|
||||
einsum_c = ctx.einsum_c
|
||||
|
||||
if not trans_a and not trans_b: # NN
|
||||
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
|
||||
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
|
||||
|
||||
elif not trans_a and trans_b: # NT
|
||||
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
|
||||
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
|
||||
|
||||
elif trans_a and not trans_b: # TN
|
||||
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
|
||||
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
|
||||
|
||||
elif trans_a and trans_b: # TT (not used)
|
||||
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
|
||||
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
|
||||
|
||||
return da, db, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
einsum = _einsum.apply
|
Reference in New Issue
Block a user