[PYTHON][OPS] Bugfix in conv fprop
This commit is contained in:
@@ -297,12 +297,18 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
size_t ld = layouts_->get(ptr)->order[0];
|
size_t ld = layouts_->get(ptr)->order[0];
|
||||||
unsigned alignment = alignment_->get(ptr, ld);
|
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
|
||||||
|
|
||||||
|
|
||||||
// vector loads
|
// vector loads
|
||||||
std::map<unsigned, Value*> packets;
|
std::map<unsigned, Value*> packets;
|
||||||
for_each(x, [&](indices_t idx){
|
for_each(x, [&](indices_t idx){
|
||||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
// vector size
|
||||||
|
unsigned contiguous = 1;
|
||||||
|
if(ld < x->get_type()->get_tile_rank())
|
||||||
|
contiguous = result->axis(ld).contiguous;
|
||||||
|
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
|
||||||
|
|
||||||
unsigned linear = result->get_linear_index(idx);
|
unsigned linear = result->get_linear_index(idx);
|
||||||
unsigned id = linear / vector_size;
|
unsigned id = linear / vector_size;
|
||||||
@@ -314,10 +320,15 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|||||||
packets[id] = builder_->CreateLoad(ptr);
|
packets[id] = builder_->CreateLoad(ptr);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// extract result element
|
// extract result element
|
||||||
for_each(x, [&](indices_t idx){
|
for_each(x, [&](indices_t idx){
|
||||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
// vector size
|
||||||
|
unsigned contiguous = 1;
|
||||||
|
if(ld < x->get_type()->get_tile_rank())
|
||||||
|
contiguous = result->axis(ld).contiguous;
|
||||||
|
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
|
||||||
unsigned linear = result->get_linear_index(idx);
|
unsigned linear = result->get_linear_index(idx);
|
||||||
unsigned id = linear / vector_size;
|
unsigned id = linear / vector_size;
|
||||||
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
|
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
|
||||||
|
@@ -242,6 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
|||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
// exit(EXIT_FAILURE);
|
// exit(EXIT_FAILURE);
|
||||||
|
// std::cout << source << std::endl;
|
||||||
cu_context::context_switcher ctx(*context);
|
cu_context::context_switcher ctx(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -221,7 +221,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
codegen::transform::cts cts;
|
codegen::transform::cts cts;
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
ir::print(module, std::cout);
|
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
// ir::print(module, std::cout);
|
// ir::print(module, std::cout);
|
||||||
|
|
||||||
|
@@ -1,11 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
N, C, K = 32, 32, 32
|
N, C, K = 32, 8, 32
|
||||||
H, W = 32, 32
|
H, W = 4, 4
|
||||||
R, S = 3, 3
|
R, S = 3, 3
|
||||||
|
torch.manual_seed(0)
|
||||||
a = torch.randn(N, C, H, W).cuda()
|
a = torch.randn(N, C, H, W).cuda()
|
||||||
b = torch.randn(C, R, S, K).cuda()
|
b = torch.ones(C, R, S, K).cuda()
|
||||||
#c = torch.nn.functional.conv2d(a, b)
|
|
||||||
c = triton.ops.conv(a, b)
|
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||||
print(c)
|
tc = triton.ops.conv(a, b)
|
||||||
|
print((rc - tc).abs().max())
|
||||||
|
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
|
||||||
|
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
|
||||||
|
#print(tc[31, 31,:,:])
|
@@ -514,7 +514,6 @@ void gen_torch_make_handles(std::ostream &os,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||||
os << " std::cout << 9 << std::endl;";
|
|
||||||
os << " std::function<void()> run = [&](){\n ";
|
os << " std::function<void()> run = [&](){\n ";
|
||||||
os << " (*id_fn_map.at(id))({";
|
os << " (*id_fn_map.at(id))({";
|
||||||
for(unsigned i = 0; i < args.size() ; i++){
|
for(unsigned i = 0; i < args.size() ; i++){
|
||||||
@@ -529,7 +528,6 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
|
|||||||
os << "}, *id_grid_map.at(id), &stream);\n";
|
os << "}, *id_grid_map.at(id), &stream);\n";
|
||||||
os << " };\n ";
|
os << " };\n ";
|
||||||
os << " run();";
|
os << " run();";
|
||||||
os << " std::cout << 10 << std::endl;";
|
|
||||||
os << " if(bench > 0)\n ";
|
os << " if(bench > 0)\n ";
|
||||||
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
|
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
|
||||||
}
|
}
|
||||||
@@ -588,14 +586,10 @@ extern std::map<size_t, int64_t> i64scalar_map;
|
|||||||
|
|
||||||
gen_torch_signature(oss, fn, outputs, name);
|
gen_torch_signature(oss, fn, outputs, name);
|
||||||
oss << " {" << std::endl;
|
oss << " {" << std::endl;
|
||||||
oss << " std::cout << 1 << std::endl;";
|
|
||||||
gen_torch_init_driver(oss, fn->args());
|
gen_torch_init_driver(oss, fn->args());
|
||||||
gen_torch_make_handles(oss, fn->args());
|
gen_torch_make_handles(oss, fn->args());
|
||||||
oss << " std::cout << 2 << std::endl;";
|
|
||||||
gen_torch_make_launch_function(oss, fn->args());
|
gen_torch_make_launch_function(oss, fn->args());
|
||||||
oss << " std::cout << 3 << std::endl;";
|
|
||||||
gen_torch_ret(oss, outputs);
|
gen_torch_ret(oss, outputs);
|
||||||
oss << " std::cout << \"done\" << std::endl;\n";
|
|
||||||
oss << "}" << std::endl;
|
oss << "}" << std::endl;
|
||||||
|
|
||||||
oss << std::endl;
|
oss << std::endl;
|
||||||
|
@@ -21,7 +21,7 @@ void convnd(A_TYPE *A,
|
|||||||
int off_uh, int off_uw,
|
int off_uh, int off_uw,
|
||||||
int off_uah, int off_uaw,
|
int off_uah, int off_uaw,
|
||||||
int off_uch, int off_ucw,
|
int off_uch, int off_ucw,
|
||||||
int* a_delta){
|
int* a_delta, int* inc_a){
|
||||||
|
|
||||||
// range of indices along the reduction axis
|
// range of indices along the reduction axis
|
||||||
int rka[TK] = 0 ... TK;
|
int rka[TK] = 0 ... TK;
|
||||||
@@ -42,8 +42,6 @@ void convnd(A_TYPE *A,
|
|||||||
int ras[TK] = rka % BW;
|
int ras[TK] = rka % BW;
|
||||||
int rac[TK] = racr / BH;
|
int rac[TK] = racr / BH;
|
||||||
int rar[TK] = racr % BH;
|
int rar[TK] = racr % BH;
|
||||||
rar = FLIPR rar;
|
|
||||||
ras = FLIPS ras;
|
|
||||||
rar = UPAR * rar;
|
rar = UPAR * rar;
|
||||||
ras = UPAS * ras;
|
ras = UPAS * ras;
|
||||||
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||||
@@ -51,56 +49,36 @@ void convnd(A_TYPE *A,
|
|||||||
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :];
|
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :];
|
||||||
|
|
||||||
// pointers for B
|
// pointers for B
|
||||||
int rb0[TN] = get_program_id(1) * TN + 0 ... TN;
|
int rbn[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||||
#ifdef B_LUT
|
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s;
|
||||||
int rbcr[TK] = rkb / BW;
|
|
||||||
int rbs[TK] = rkb % BW;
|
|
||||||
int rbc[TK] = rbcr / BH;
|
|
||||||
int rbr[TK] = rbcr % BH;
|
|
||||||
rbr = rbr * upsample_h + off_uh;
|
|
||||||
rbs = rbs * upsample_w + off_uw;
|
|
||||||
int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
|
|
||||||
#else
|
|
||||||
int rb1[TK] = rkb * STRIDE_B0;
|
|
||||||
#endif
|
|
||||||
B_TYPE* pb [B_SHAPE] = B + rb1[BROADCAST_B1] * STRIDE_B1 + rb0[BROADCAST_B0] * STRIDE_B0 * ldb_k;
|
|
||||||
|
|
||||||
// pointers for A look-up table
|
// pointers for A look-up table
|
||||||
int offda[TK] = rka % LUT_SIZE;
|
int offda[TK] = rka % LUT_SIZE;
|
||||||
int* pincd[TK] = a_delta + offda;
|
int* pincd[TK] = inc_a + offda;
|
||||||
int* pda[TK] = a_delta + LUT_SIZE + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
||||||
int da[TK] = *pda;
|
int da[TK] = *pda;
|
||||||
int incd[TK] = *pincd;
|
int incd[TK] = *pincd;
|
||||||
|
|
||||||
// pointers for B look-up table
|
|
||||||
int offdb[TK] = rkb % LUT_SIZE;
|
|
||||||
#ifdef B_LUT
|
|
||||||
int* pdb[TK] = b_delta + offdb + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
|
||||||
int db[TK] = *pdb;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// reduction loop
|
// reduction loop
|
||||||
A_TYPE a[TM, TK] = *pa;
|
A_TYPE a[TM, TK] = *pa;
|
||||||
B_TYPE b[B_SHAPE] = *pb;
|
B_TYPE b[TK, TN] = *pb;
|
||||||
for(int k = K; k > 0; k = k - TK){
|
for(int k = K; k > 0; k = k - TK){
|
||||||
c += a @ USE_B;
|
c += a @ b;
|
||||||
pa = pa + da[newaxis, :];
|
pa += da[newaxis, :];
|
||||||
pb = pb + INC_PB;
|
pb += TK * ldb_s;
|
||||||
// increment A look-up table
|
// increment A look-up table
|
||||||
pda = pda + incd;
|
pda = pda + incd;
|
||||||
da = *pda;
|
da = *pda;
|
||||||
pincd = pincd + incd;
|
pincd = pincd + incd;
|
||||||
incd = *pincd;
|
incd = *pincd;
|
||||||
// increment B look-up table
|
|
||||||
#ifdef B_LUT
|
|
||||||
pdb = pdb + INC_PDB;
|
|
||||||
db = *pdb;
|
|
||||||
#endif
|
|
||||||
// pre-fetches
|
// pre-fetches
|
||||||
a = *pa;
|
bool checka[TM, TK] = k > TK;
|
||||||
b = *pb;
|
bool checkb[TK, TN] = k > TK;
|
||||||
|
a = checka ? *pa : 0;
|
||||||
|
b = checkb ? *pb : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// write back
|
// write back
|
||||||
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
|
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||||
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
|
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||||
@@ -112,28 +90,31 @@ void convnd(A_TYPE *A,
|
|||||||
rcq = rcq * upsample_w + off_ucw;
|
rcq = rcq * upsample_w + off_ucw;
|
||||||
int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
|
int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
|
||||||
float* pc[TM, TN] = C + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
float* pc[TM, TN] = C + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||||
*pc = c;
|
bool checkc0[TM] = rxc < M;
|
||||||
|
bool checkc1[TN] = rc1 < N;
|
||||||
|
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
|
*?(checkc)pc = c;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
kernel = triton.kernel(src, ['C'])
|
kernel = triton.kernel(src, ['C'])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _unpack(idx, D, H, W):
|
def _unpack(idx, D, H, W):
|
||||||
c = idx // (D*H*W)
|
cdh = idx // W
|
||||||
dhw = idx % (D*H*W)
|
w = idx % W
|
||||||
dh = dhw // W
|
cd = cdh // H
|
||||||
w = dhw % W
|
h = cdh % H
|
||||||
d = dh // H
|
c = cd // D
|
||||||
h = dh % H
|
d = cd % D
|
||||||
return c, d, h, w
|
return c, d, h, w
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _delta_a(upsample_d, upsample_h, upsample_w, depth, TK,
|
def _delta_a(upsample_d, upsample_h, upsample_w, depth, TK,
|
||||||
T, R, S, stride_a):
|
T, R, S, stride_a):
|
||||||
ud = np.arange(upsample_d)[:, np.newaxis, np.newaxis, np.newaxis]
|
ud = np.arange(upsample_d, dtype=np.int32)[:, np.newaxis, np.newaxis, np.newaxis]
|
||||||
uh = np.arange(upsample_h)[np.newaxis, :, np.newaxis, np.newaxis]
|
uh = np.arange(upsample_h, dtype=np.int32)[np.newaxis, :, np.newaxis, np.newaxis]
|
||||||
uw = np.arange(upsample_w)[np.newaxis, np.newaxis, :, np.newaxis]
|
uw = np.arange(upsample_w, dtype=np.int32)[np.newaxis, np.newaxis, :, np.newaxis]
|
||||||
ctrs = np.arange(depth)[np.newaxis, np.newaxis, np.newaxis, :]
|
ctrs = np.arange(depth, dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
|
||||||
c, t, r, s = _conv._unpack(ctrs, T, R, S)
|
c, t, r, s = _conv._unpack(ctrs, T, R, S)
|
||||||
nextc, nextt, nextr, nexts = _conv._unpack(ctrs + TK, T, R, S)
|
nextc, nextt, nextr, nexts = _conv._unpack(ctrs + TK, T, R, S)
|
||||||
cdiff = nextc - c
|
cdiff = nextc - c
|
||||||
@@ -181,31 +162,18 @@ void convnd(A_TYPE *A,
|
|||||||
delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
|
delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
|
||||||
depth, TK, BD, BH, BW, stride_a)
|
depth, TK, BD, BH, BW, stride_a)
|
||||||
delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
|
delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
|
||||||
|
inc_a = np.arange(depth, dtype=np.int32)
|
||||||
|
inc_a = ((inc_a + TK) % depth) - inc_a
|
||||||
|
inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
|
||||||
|
|
||||||
trans_b = False
|
trans_b = False
|
||||||
is_wgrad = False
|
is_wgrad = False
|
||||||
is_blut = False
|
is_blut = False
|
||||||
macros = {
|
macros = {
|
||||||
'B_SHAPE': 'TN, TK' if trans_b else 'TK, TN',
|
|
||||||
'BROADCAST_B0': ':, newaxis' if trans_b else 'newaxis, :',
|
|
||||||
'BROADCAST_B1': 'newaxis, :' if trans_b else ':, newaxis',
|
|
||||||
'STRIDE_B0': 'ldb_s' if trans_b else '1',
|
|
||||||
'STRIDE_B1': '1' if trans_b else 'ldb_s',
|
|
||||||
'USE_B': '^b' if trans_b else 'b',
|
|
||||||
'FLIPR': '' if trans_b else 'BH - 1 -',
|
|
||||||
'FLIPS': '' if trans_b else 'BW - 1 -',
|
|
||||||
'UPAR': 'stride_h' if is_wgrad else '1',
|
'UPAR': 'stride_h' if is_wgrad else '1',
|
||||||
'UPAS': 'stride_w' if is_wgrad else '1',
|
'UPAS': 'stride_w' if is_wgrad else '1',
|
||||||
'UPAH': '' if is_wgrad else 'stride_h',
|
'UPAH': '' if is_wgrad else 'stride_h',
|
||||||
'UPAW': '' if is_wgrad else 'stride_w',
|
'UPAW': '' if is_wgrad else 'stride_w',
|
||||||
'REDAX0': 'NC' if trans_b else 'BH',
|
|
||||||
'REDAX1': 'BH' if trans_b else 'BW',
|
|
||||||
'REDAX2': 'BW' if trans_b else 'NC',
|
|
||||||
'AX0': 'c' if trans_b else 'r',
|
|
||||||
'AX1': 'r' if trans_b else 's',
|
|
||||||
'AX2': 's' if trans_b else 'c',
|
|
||||||
'INC_PB': 'db[newaxis, :]' if is_blut else 'TK',
|
|
||||||
'INC_PDB': 'incd' if trans_b else 'TK',
|
|
||||||
'LUT_SIZE': depth,
|
'LUT_SIZE': depth,
|
||||||
'TM': [32],
|
'TM': [32],
|
||||||
'TN': [32],
|
'TN': [32],
|
||||||
@@ -215,20 +183,22 @@ void convnd(A_TYPE *A,
|
|||||||
}
|
}
|
||||||
|
|
||||||
shape_c.pop(2)
|
shape_c.pop(2)
|
||||||
print(shape_c)
|
|
||||||
c = triton.empty(shape_c, dtype=a.dtype)
|
c = triton.empty(shape_c, dtype=a.dtype)
|
||||||
_conv.kernel(a, b, c, CD*CH*CW, NF, NC*BD*BH*BW, AH, AW, BH, BW, CH, CW, NC,
|
grid = lambda opt: [triton.cdiv(NB*CD*CH*CW, opt.d('TM')), triton.cdiv(NF, opt.d('TN'))]
|
||||||
|
print(stride_c)
|
||||||
|
print(stride_b)
|
||||||
|
_conv.kernel(a, b, c, NB*CD*CH*CW, NF, NC*BD*BH*BW, AH, AW, BH, BW, CH, CW, NC,
|
||||||
stride_a[0], stride_a[1], stride_a[2], stride_a[3], stride_a[4],
|
stride_a[0], stride_a[1], stride_a[2], stride_a[3], stride_a[4],
|
||||||
stride_b[0], stride_b[1], stride_b[2], stride_b[3], stride_b[4],
|
stride_b[0], stride_b[1], stride_b[2], stride_b[3], stride_b[4],
|
||||||
stride_c[0], stride_c[1], stride_c[2], stride_c[3], stride_c[4],
|
stride_c[0], stride_c[1], stride_c[2], stride_c[3], stride_c[4],
|
||||||
pad_h, pad_w, stride_h, stride_w, upsample_h, upsample_w,
|
pad_h, pad_w, stride_h, stride_w, upsample_h, upsample_w,
|
||||||
0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0,
|
||||||
delta_a,
|
delta_a, inc_a,
|
||||||
lambda opt: (1, 1, 1), **macros)
|
grid, **macros)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input, weight):
|
def forward(ctx, input, weight):
|
||||||
_conv._call(input, weight, 1, 1, 1, 0, 0, 0, 1, 1, 1, '')
|
return _conv._call(input, weight, 1, 1, 1, 0, 0, 0, 1, 1, 1, '')
|
||||||
|
|
||||||
conv = _conv.apply
|
conv = _conv.apply
|
Reference in New Issue
Block a user