more stuff

This commit is contained in:
Philippe Tillet
2019-06-30 16:55:02 -07:00
parent 9a86bc51e1
commit c172bd518b
9 changed files with 124 additions and 90 deletions

View File

@@ -16,6 +16,7 @@ int main() {
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler // initialize just-in-time compiler
triton::jit jit(context); triton::jit jit(context);
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t BS = 32, F = 1024; int32_t BS = 32, F = 1024;
@@ -30,7 +31,7 @@ int main() {
shift_w[c] = rand() % S - S/2; shift_w[c] = rand() % S - S/2;
} }
// configuration // configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str); triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
// host buffers // host buffers
std::vector<float> hc(shift.c_size()); std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size()); std::vector<float> rc(shift.c_size());
@@ -58,7 +59,7 @@ int main() {
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) { triton::jit::launch_information info) {
shift.init(stream, (triton::driver::cu_module*)kernel->module()); shift.init(stream, (triton::driver::cu_module*)kernel->module());
// launch info // launch infoRR
unsigned TM = info.global_range_size[0]; unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1]; unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads; unsigned nthreads = info.num_threads;
@@ -78,7 +79,7 @@ int main() {
std::ostringstream oss; std::ostringstream oss;
shift.src(oss); shift.src(oss);
std::string src = oss.str(); std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark); jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params); jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift"); triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift"); triton::jit::launch_information info = jit.get_launch_info("shift");

View File

@@ -38,7 +38,9 @@ class shift {
public: public:
enum type { enum type {
FPROP FPROP,
BPROP,
WGRAD
}; };
private: private:
@@ -85,11 +87,11 @@ public:
OUT_DTYPE acc; OUT_DTYPE acc;
for(int32_t p = 0; p < AH_; ++p) for(int32_t p = 0; p < AH_; ++p)
for(int32_t q = 0; q < AW_; ++q) for(int32_t q = 0; q < AW_; ++q)
for(int32_t bs = 0; bs < NB_; ++bs) for(int32_t bs = 0; bs < B_; ++bs)
for(int32_t k = 0; k < NF_; ++k) for(int32_t k = 0; k < F_; ++k)
{ {
acc = 0; acc = 0;
for(int32_t c = 0; c < NC_; ++c){ for(int32_t c = 0; c < C_; ++c){
int32_t h = p; int32_t h = p;
int32_t w = q; int32_t w = q;
if(h >= BH_/2 && h < AH_ - BH_/2 if(h >= BH_/2 && h < AH_ - BH_/2
@@ -97,11 +99,11 @@ public:
h += shift_h_[c]; h += shift_h_[c];
w += shift_w_[c]; w += shift_w_[c];
} }
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]; IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_];
IN_DTYPE b = F[k + c*NF_]; IN_DTYPE b = F[k + c*F_];
acc = std::fma(a, b, acc); acc = std::fma(a, b, acc);
} }
O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc; O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc;
} }
} }
@@ -109,8 +111,8 @@ private:
int32_t MAX_C_; int32_t MAX_C_;
int32_t TK_; int32_t TK_;
// image size // image size
int32_t NB_; int32_t B_;
int32_t NC_; int32_t C_;
int32_t AD_; int32_t AD_;
int32_t AH_; int32_t AH_;
int32_t AW_; int32_t AW_;
@@ -118,7 +120,7 @@ private:
int32_t BD_; int32_t BD_;
int32_t BH_; int32_t BH_;
int32_t BW_; int32_t BW_;
int32_t NF_; int32_t F_;
// activation size // activation size
int32_t CD_; int32_t CD_;
int32_t CH_; int32_t CH_;
@@ -149,6 +151,9 @@ private:
// convolution type // convolution type
type ty_; type ty_;
bool bias_; bool bias_;
// transpose
bool AT_;
bool BT_;
}; };
} }

View File

@@ -160,13 +160,13 @@ private:
class indexing_expression: public postfix_expression{ class indexing_expression: public postfix_expression{
public: public:
indexing_expression(node *id, node *slices) indexing_expression(node *lhs, node *slices)
: id_((const identifier*)id), slices_((const list<slice*>*)slices) {} : lhs_((const expression*)lhs), slices_((const list<slice*>*)slices) {}
ir::value* codegen(ir::module *) const; ir::value* codegen(ir::module *) const;
private: private:
const identifier* id_; const expression* lhs_;
const list<slice*>* slices_; const list<slice*>* slices_;
}; };

View File

@@ -157,7 +157,7 @@ slice_list
postfix_expression postfix_expression
: primary_expression { $$ = $1;} : primary_expression { $$ = $1;}
| identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);} | primary_expression '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
; ;
/* Unary */ /* Unary */

View File

@@ -65,9 +65,9 @@ public:
target_(target) { } target_(target) { }
void target_independent(ir::module &module) { void target_independent(ir::module &module) {
optimize_dot.run(module); ir::print(module, std::cout);
optimize_trans.run(module); optimize_dot.run(module);
// ir::print(module, std::cout); optimize_trans.run(module);
} }
void target_dependent(ir::module &module) { void target_dependent(ir::module &module) {

View File

@@ -59,7 +59,7 @@ void tune::init_c_graph(ir::instruction *v) {
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v)) else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
return; return;
else{ else{
// std::cout << v->get_name() << std::endl; std::cout << v->get_name() << std::endl;
shapes = v->get_type()->get_tile_shapes(); shapes = v->get_type()->get_tile_shapes();
} }
// Reshape // Reshape

View File

@@ -8,42 +8,63 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) { std::vector<int32_t>& ld) {
size_t size = shapes.size(); size_t size = shapes.size();
ld.resize(size); ld.resize(size);
ld[3] = 1; ld[size - 1] = 1;
ld[2] = shapes[3]*ld[3]; for(int i = size - 1; i >= 1; i--)
ld[1] = shapes[2]*ld[2]; ld[i - 1] = shapes[i] * ld[i];
ld[0] = shapes[1]*ld[1];
} }
shift::shift(int B, int NC, shift::shift(int B, int C,
int D, int H, int W, int D, int H, int W,
int T, int R, int S, int T, int R, int S,
int NF, int F,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w, const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
type ty, bool bias) type ty, bool bias)
: NB_(B), NC_(NC), : B_(B), C_(C),
AD_(D), AH_(H), AW_(W), AD_(D), AH_(H), AW_(W),
BD_(T), BH_(R), BW_(S), BD_(T), BH_(R), BW_(S),
NF_(NF), F_(F),
shift_h_(shift_h), shift_w_(shift_w), shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty), a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias) { ty_(ty), bias_(bias) {
// max number of channels // max number of channels
TK_ = 16; TK_ = 16;
MAX_C_ = 8192 + TK_; MAX_C_ = 8192 + TK_;
// transpose
AT_ = false;
BT_ = true;
// equivalent matmul // equivalent matmul
M_ = NB_*AH_*AW_; M_ = B_*AH_*AW_;
N_ = NF_; N_ = F_;
K_ = NC_; K_ = C_;
// shapes // shapes
// input layout: C, H, W, BS // input layout: C, H, W, B
// filter layout: C, K // filter layout: C, F
// output layout: K, H, W, BS // output layout: F, H, W, B
shapes_a_ = {NC, H, W, B}; shapes_a_ = {C, H, W, B};
shapes_b_ = {NC, NF}; shapes_b_ = {C, F};
shapes_c_ = {NF, H, W, B}; shapes_c_ = {F, H, W, B};
if(ty_ == WGRAD){
shapes_b_.swap(shapes_c_);
shapes_a_.swap(shapes_b_);
AT_ = true;
BT_ = false;
M_ = K_;
N_ = C_;
K_ = B_*AH_*AW_;
}
if(ty_ == BPROP){
shapes_a_.swap(shapes_c_);
AT_ = false;
BT_ = false;
K_ = F_;
M_ = B_*AH_*AW_;
N_ = C_;
}
// memory strides // memory strides
set_ld(shapes_a_, ld_a_); set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// build LUTs // build LUTs
build_deltas(); build_deltas();
} }
@@ -57,7 +78,7 @@ void shift::build_deltas() {
// populate look-up table // populate look-up table
for(unsigned c = 0; c < TK_; c++) for(unsigned c = 0; c < TK_; c++)
h_deltas_[c] = offset(c); h_deltas_[c] = offset(c);
for(unsigned c = 0; c < NC_; c++) for(unsigned c = 0; c < C_; c++)
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c); h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
} }
@@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_); kernel->setArg(3, M_);
kernel->setArg(4, N_); kernel->setArg(4, N_);
kernel->setArg(5, K_); kernel->setArg(5, K_);
kernel->setArg(6, NB_*AH_*AW_); kernel->setArg(6, B_*AH_*AW_);
kernel->setArg(7, NB_); kernel->setArg(7, B_);
kernel->setArg(8, AH_); kernel->setArg(8, AH_);
kernel->setArg(9, AW_); kernel->setArg(9, AW_);
kernel->setArg(10, BH_); kernel->setArg(10, BH_);
kernel->setArg(11, BW_); kernel->setArg(11, BW_);
// dry run
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
void shift::src(std::ostream &os) { void shift::src(std::ostream &os) {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
os << os <<
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TM = {16, 32, 64, 128};
@@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
int32 raw[TM] = rawhc % AW; int32 raw[TM] = rawhc % AW;
int32 rahc[TM] = rawhc / AW; int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH; int32 rah[TM] = rahc % AH;
__constant__ int32* pd[TK] = delta + rka;
multiple_of(4) int32 d[TK] = *pd;
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis]; int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
__constant__ int32* pd[TK] = delta + rka; int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
multiple_of(4) int32 d[TK]; int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
d = *pd; )" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
int32 offa1[TK] = rka*lda; )" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :]; )" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc; )" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
)" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
)" << a_ty_ << R"( a[TM, TK] = *pa;
)" << b_ty_ << R"( b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){ for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C); C = dot()" << usea << "," << useb << R"(, C);
pb = pb + TK*N; pb = pb + TK*N;
pd = pd + TK; pd = pd + TK;
d = *pd; d = *pd;
pa = pa + (mask ? d[newaxis, :] : TK*lda); inc_true = d)" << bca0 << R"(;
int1 checka[TM, TK] = k > TK; inc_false = TK * lda;
int1 checkb[TN, TK] = k > TK; pa = pa + (mask ? inc_true : inc_false);
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
@checka a = *pa; @checka a = *pa;
@checkb b = *pb; @checkb b = *pb;
} }

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl; // std::cout << source << sd::endl;
cu_context::context_switcher ctx_switch(*context); cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -175,7 +175,7 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
/* Postfix expression */ /* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{ ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name()); ir::value *in = lhs_->codegen(mod);
const std::vector<slice*> &slices = slices_->values(); const std::vector<slice*> &slices = slices_->values();
auto in_shapes = in->get_type()->get_tile_shapes(); auto in_shapes = in->get_type()->get_tile_shapes();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
@@ -234,44 +234,32 @@ ir::value* cast_expression::codegen(ir::module *mod) const{
/* Conditional expression */ /* Conditional expression */
ir::value *conditional_expression::codegen(ir::module *mod) const{ ir::value *conditional_expression::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder(); ir::builder &builder = mod->get_builder();
ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
ir::value *pred = cond_->codegen(mod); ir::value *pred = cond_->codegen(mod);
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred); ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
/* true value */
ir::value *true_mask = mask->get_result(0); ir::value *true_mask = mask->get_result(0);
ir::value *false_mask = mask->get_result(1); auto it_true_begin = instructions.end();
it_true_begin--;
ir::value *true_value = true_value_->codegen(mod); ir::value *true_value = true_value_->codegen(mod);
ir::value *false_value = false_value_->codegen(mod);
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(true_mask);
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
itn->set_mask_pred(false_mask);
bool is_float, is_ptr, is_int, is_signed;
ir::value *uncasted_true_value = true_value;
ir::value *uncasted_false_value = false_value;
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, pred, true_value); implicit_broadcast(mod, pred, true_value);
it_true_begin++;
auto it_true_end = instructions.end();
for(auto it = it_true_begin; it != it_true_end; it++)
(*it)->set_mask_pred(true_mask);
/* false value */
ir::value *false_mask = mask->get_result(1);
auto it_false_begin = instructions.end();
it_false_begin--;
ir::value *false_value = false_value_->codegen(mod);
it_false_begin++;
implicit_broadcast(mod, pred, false_value); implicit_broadcast(mod, pred, false_value);
{ auto it_false_end = instructions.end();
ir::value *current = true_value; for(auto it = it_false_begin; it != it_false_end; it++)
while(current != uncasted_true_value) { (*it)->set_mask_pred(false_mask);
if(auto *itn = dynamic_cast<ir::instruction*>(current)){ /* cast */
itn->set_mask_pred(true_mask); bool is_float, is_ptr, is_int, is_signed;
current = itn->get_operand(0); implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
}
else
break;
}
}
{
ir::value *current = false_value;
while(current != uncasted_false_value) {
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
itn->set_mask_pred(false_mask);
current = itn->get_operand(0);
}
else
break;
}
}
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
return result; return result;
} }