more work

Author: Philippe Tillet
Date:   2019-10-26 15:10:19 -04:00
Parent: 76adcb755a
Commit: 655f43fb5b
5 changed files with 46 additions and 8 deletions


@@ -187,6 +187,7 @@ generator::generator(analysis::axes *a_axes,
void generator::visit_value(ir::value* v) {
std::cout << "visiting " << typeid(*v).name() << std::endl;
if(!seen_.insert(v).second)
return;
// create machine tile
@@ -559,8 +560,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
bool is_a_row = is_a_trans ^ (ord_a[0] == 1);
bool is_b_row = is_b_trans ^ (ord_b[0] == 1);
Value *offset_a_i = hmma->offset_a_i_;
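
Note on the hunk above: the layout test switches from ord_a[ord_a.size() - 2] to ord_a[0]. Assuming the order vector lists axes from fastest- to slowest-varying, an operand is row-major exactly when its fastest-varying axis is axis 1, and the transpose flag flips that. A minimal sketch of the predicate (Python, illustrative only, not the generator code itself):

# Illustrative sketch; `order` is assumed to list axes fastest-varying first,
# e.g. [1, 0] for a row-major 2D operand and [0, 1] for a column-major one.
def effective_row_major(order, is_trans):
    # axis 1 fastest-varying => row-major; a transpose flips the result
    return is_trans ^ (order[0] == 1)

assert effective_row_major([1, 0], is_trans=False)      # row-major, no transpose
assert not effective_row_major([1, 0], is_trans=True)   # transpose flips it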


@@ -1,4 +1,5 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"


@@ -221,6 +221,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
std::cout << "begin" << std::endl;
disassociate.run(module);
dce.run(module);
peephole.run(module);
@@ -244,9 +245,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
std::cout << "isel" << std::endl;
// ir::print(module, std::cout);
// exit(EXIT_FAILURE);
isel.visit(module, *llvm);
std::cout << "done" << std::endl;
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done


@@ -1,2 +1,2 @@
from .dot import dot
from .einsum import einsum
from .dot import _dot, dot
from .einsum import _einsum, einsum


@@ -3,7 +3,7 @@ import triton
class _einsum(triton.function):
src = """
void einsum(TYPE * A, TYPE * B, TYPE * C,
void einsum_(TYPE * A, TYPE * B, TYPE * C,
int dim_M, int dim_N, int dim_K,
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
@@ -30,7 +30,7 @@ class _einsum(triton.function):
for(int k = dim_K; k > 0; k -= TK) {
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
c += a @ b;
c += USE_A @ USE_B;
pa += TK;
pb += TK;
}
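
Note on the hunk above: a @ b becomes USE_A @ USE_B, presumably so the host code can splice in a transposed operand at compile time and keep a single kernel source. A hypothetical sketch of the macro table (the '^a' transpose spelling and the helper name are assumptions, not taken from this diff):

# Hypothetical macro table; only USE_A/USE_B appear in the kernel source above.
def operand_defines(trans_a, trans_b):
    return {
        "USE_A": "^a" if trans_a else "a",  # '^' as transpose is an assumption
        "USE_B": "^b" if trans_b else "b",
    }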
@@ -157,6 +157,7 @@ class _einsum(triton.function):
@staticmethod
def forward(ctx, subscripts, a, b):
ctx.save_for_backward(a, b)
if type(subscripts) is str:
einsum_a, einsum_bc = subscripts.split(",")
einsum_b, einsum_c = einsum_bc.split("->")
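
For reference, the string form of subscripts is split exactly as in the two lines above; a quick worked example (the subscript string is illustrative):

subscripts = "mk,kn->mn"
einsum_a, einsum_bc = subscripts.split(",")   # "mk", "kn->mn"
einsum_b, einsum_c = einsum_bc.split("->")    # "kn", "mn"
assert (einsum_a, einsum_b, einsum_c) == ("mk", "kn", "mn")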
@@ -165,8 +166,41 @@ class _einsum(triton.function):
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
einsum_a, einsum_b, einsum_c,
a.shape, b.shape
a.shape.as_list(), b.shape.as_list()
)
ctx.trans_a = ta
ctx.trans_b = tb
ctx.einsum_a = einsum_a
ctx.einsum_b = einsum_b
ctx.einsum_c = einsum_c
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
@staticmethod
def backward(ctx, dc):
a, b = ctx.saved_tensors
trans_a = ctx.trans_a
trans_b = ctx.trans_b
einsum_a = ctx.einsum_a
einsum_b = ctx.einsum_b
einsum_c = ctx.einsum_c
if not trans_a and not trans_b: # NN
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
elif not trans_a and trans_b: # NT
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
elif trans_a and not trans_b: # TN
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
elif trans_a and trans_b: # TT (not used)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
return da, db, None, None, None, None, None, None, None, None, None, None
einsum = _einsum.apply
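
The backward branches above rotate the subscripts rather than hard-coding transposes. A small NumPy check of the NN branch, under the assumption that the Triton einsum matches numpy.einsum semantics for these subscripts (the tuple-of-subscripts call form is flattened into a plain string here):

# Sanity check of the NN gradient rules with numpy.einsum (illustrative only).
import numpy as np

einsum_a, einsum_b, einsum_c = "mk", "kn", "mn"
a = np.random.randn(4, 5)
b = np.random.randn(5, 6)
dc = np.random.randn(4, 6)

# NN branch: da = einsum((c, b, a), dc, b), db = einsum((a, c, b), a, dc)
da = np.einsum(f"{einsum_c},{einsum_b}->{einsum_a}", dc, b)
db = np.einsum(f"{einsum_a},{einsum_c}->{einsum_b}", a, dc)

# For c = a @ b these must equal the familiar matmul gradients.
assert np.allclose(da, dc @ b.T)
assert np.allclose(db, a.T @ dc)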