more stuff

This commit is contained in:
Philippe Tillet
2019-06-30 16:55:02 -07:00
parent 9a86bc51e1
commit c172bd518b
9 changed files with 124 additions and 90 deletions

View File

@@ -16,6 +16,7 @@ int main() {
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t R = 3, S = 3;
int32_t BS = 32, F = 1024;
@@ -30,7 +31,7 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str);
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());
@@ -58,7 +59,7 @@ int main() {
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
shift.init(stream, (triton::driver::cu_module*)kernel->module());
// launch info
// launch infoRR
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
@@ -78,7 +79,7 @@ int main() {
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark);
jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");

View File

@@ -38,7 +38,9 @@ class shift {
public:
enum type {
FPROP
FPROP,
BPROP,
WGRAD
};
private:
@@ -85,11 +87,11 @@ public:
OUT_DTYPE acc;
for(int32_t p = 0; p < AH_; ++p)
for(int32_t q = 0; q < AW_; ++q)
for(int32_t bs = 0; bs < NB_; ++bs)
for(int32_t k = 0; k < NF_; ++k)
for(int32_t bs = 0; bs < B_; ++bs)
for(int32_t k = 0; k < F_; ++k)
{
acc = 0;
for(int32_t c = 0; c < NC_; ++c){
for(int32_t c = 0; c < C_; ++c){
int32_t h = p;
int32_t w = q;
if(h >= BH_/2 && h < AH_ - BH_/2
@@ -97,11 +99,11 @@ public:
h += shift_h_[c];
w += shift_w_[c];
}
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_];
IN_DTYPE b = F[k + c*NF_];
IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_];
IN_DTYPE b = F[k + c*F_];
acc = std::fma(a, b, acc);
}
O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc;
O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc;
}
}
@@ -109,8 +111,8 @@ private:
int32_t MAX_C_;
int32_t TK_;
// image size
int32_t NB_;
int32_t NC_;
int32_t B_;
int32_t C_;
int32_t AD_;
int32_t AH_;
int32_t AW_;
@@ -118,7 +120,7 @@ private:
int32_t BD_;
int32_t BH_;
int32_t BW_;
int32_t NF_;
int32_t F_;
// activation size
int32_t CD_;
int32_t CH_;
@@ -149,6 +151,9 @@ private:
// convolution type
type ty_;
bool bias_;
// transpose
bool AT_;
bool BT_;
};
}

View File

@@ -160,13 +160,13 @@ private:
class indexing_expression: public postfix_expression{
public:
indexing_expression(node *id, node *slices)
: id_((const identifier*)id), slices_((const list<slice*>*)slices) {}
indexing_expression(node *lhs, node *slices)
: lhs_((const expression*)lhs), slices_((const list<slice*>*)slices) {}
ir::value* codegen(ir::module *) const;
private:
const identifier* id_;
const expression* lhs_;
const list<slice*>* slices_;
};

View File

@@ -157,7 +157,7 @@ slice_list
postfix_expression
: primary_expression { $$ = $1;}
| identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
| primary_expression '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
;
/* Unary */

View File

@@ -65,9 +65,9 @@ public:
target_(target) { }
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
// ir::print(module, std::cout);
ir::print(module, std::cout);
optimize_dot.run(module);
optimize_trans.run(module);
}
void target_dependent(ir::module &module) {

View File

@@ -59,7 +59,7 @@ void tune::init_c_graph(ir::instruction *v) {
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
return;
else{
// std::cout << v->get_name() << std::endl;
std::cout << v->get_name() << std::endl;
shapes = v->get_type()->get_tile_shapes();
}
// Reshape

View File

@@ -8,42 +8,63 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[3] = 1;
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}
shift::shift(int B, int NC,
shift::shift(int B, int C,
int D, int H, int W,
int T, int R, int S,
int NF,
int F,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: NB_(B), NC_(NC),
: B_(B), C_(C),
AD_(D), AH_(H), AW_(W),
BD_(T), BH_(R), BW_(S),
NF_(NF),
F_(F),
shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias) {
// max number of channels
TK_ = 16;
MAX_C_ = 8192 + TK_;
// transpose
AT_ = false;
BT_ = true;
// equivalent matmul
M_ = NB_*AH_*AW_;
N_ = NF_;
K_ = NC_;
M_ = B_*AH_*AW_;
N_ = F_;
K_ = C_;
// shapes
// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
shapes_a_ = {NC, H, W, B};
shapes_b_ = {NC, NF};
shapes_c_ = {NF, H, W, B};
// input layout: C, H, W, B
// filter layout: C, F
// output layout: F, H, W, B
shapes_a_ = {C, H, W, B};
shapes_b_ = {C, F};
shapes_c_ = {F, H, W, B};
if(ty_ == WGRAD){
shapes_b_.swap(shapes_c_);
shapes_a_.swap(shapes_b_);
AT_ = true;
BT_ = false;
M_ = K_;
N_ = C_;
K_ = B_*AH_*AW_;
}
if(ty_ == BPROP){
shapes_a_.swap(shapes_c_);
AT_ = false;
BT_ = false;
K_ = F_;
M_ = B_*AH_*AW_;
N_ = C_;
}
// memory strides
set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// build LUTs
build_deltas();
}
@@ -57,7 +78,7 @@ void shift::build_deltas() {
// populate look-up table
for(unsigned c = 0; c < TK_; c++)
h_deltas_[c] = offset(c);
for(unsigned c = 0; c < NC_; c++)
for(unsigned c = 0; c < C_; c++)
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
}
@@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, NB_*AH_*AW_);
kernel->setArg(7, NB_);
kernel->setArg(6, B_*AH_*AW_);
kernel->setArg(7, B_);
kernel->setArg(8, AH_);
kernel->setArg(9, AW_);
kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
// dry run
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
void shift::src(std::ostream &os) {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
os <<
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
int32 raw[TM] = rawhc % AW;
int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH;
__constant__ int32* pd[TK] = delta + rka;
multiple_of(4) int32 d[TK] = *pd;
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
__constant__ int32* pd[TK] = delta + rka;
multiple_of(4) int32 d[TK];
d = *pd;
int32 offa1[TK] = rka*lda;
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
)" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
)" << a_ty_ << R"( a[TM, TK] = *pa;
)" << b_ty_ << R"( b[TN, TK] = *pb;
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
C = dot()" << usea << "," << useb << R"(, C);
pb = pb + TK*N;
pd = pd + TK;
d = *pd;
pa = pa + (mask ? d[newaxis, :] : TK*lda);
int1 checka[TM, TK] = k > TK;
int1 checkb[TN, TK] = k > TK;
inc_true = d)" << bca0 << R"(;
inc_false = TK * lda;
pa = pa + (mask ? inc_true : inc_false);
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
@checka a = *pa;
@checkb b = *pb;
}

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
// std::cout << source << sd::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -175,7 +175,7 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name());
ir::value *in = lhs_->codegen(mod);
const std::vector<slice*> &slices = slices_->values();
auto in_shapes = in->get_type()->get_tile_shapes();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
@@ -234,44 +234,32 @@ ir::value* cast_expression::codegen(ir::module *mod) const{
/* Conditional expression */
ir::value *conditional_expression::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
ir::value *pred = cond_->codegen(mod);
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
/* true value */
ir::value *true_mask = mask->get_result(0);
ir::value *false_mask = mask->get_result(1);
auto it_true_begin = instructions.end();
it_true_begin--;
ir::value *true_value = true_value_->codegen(mod);
ir::value *false_value = false_value_->codegen(mod);
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(true_mask);
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
itn->set_mask_pred(false_mask);
bool is_float, is_ptr, is_int, is_signed;
ir::value *uncasted_true_value = true_value;
ir::value *uncasted_false_value = false_value;
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, pred, true_value);
it_true_begin++;
auto it_true_end = instructions.end();
for(auto it = it_true_begin; it != it_true_end; it++)
(*it)->set_mask_pred(true_mask);
/* false value */
ir::value *false_mask = mask->get_result(1);
auto it_false_begin = instructions.end();
it_false_begin--;
ir::value *false_value = false_value_->codegen(mod);
it_false_begin++;
implicit_broadcast(mod, pred, false_value);
{
ir::value *current = true_value;
while(current != uncasted_true_value) {
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
itn->set_mask_pred(true_mask);
current = itn->get_operand(0);
}
else
break;
}
}
{
ir::value *current = false_value;
while(current != uncasted_false_value) {
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
itn->set_mask_pred(false_mask);
current = itn->get_operand(0);
}
else
break;
}
}
auto it_false_end = instructions.end();
for(auto it = it_false_begin; it != it_false_end; it++)
(*it)->set_mask_pred(false_mask);
/* cast */
bool is_float, is_ptr, is_int, is_signed;
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
return result;
}