more stuff
This commit is contained in:
@@ -16,6 +16,7 @@ int main() {
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// initialize just-in-time compiler
|
||||
triton::jit jit(context);
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t BS = 32, F = 1024;
|
||||
@@ -30,7 +31,7 @@ int main() {
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str);
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
@@ -58,7 +59,7 @@ int main() {
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
// launch info
|
||||
// launch infoRR
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
@@ -78,7 +79,7 @@ int main() {
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
|
@@ -38,7 +38,9 @@ class shift {
|
||||
|
||||
public:
|
||||
enum type {
|
||||
FPROP
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -85,11 +87,11 @@ public:
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < AH_; ++p)
|
||||
for(int32_t q = 0; q < AW_; ++q)
|
||||
for(int32_t bs = 0; bs < NB_; ++bs)
|
||||
for(int32_t k = 0; k < NF_; ++k)
|
||||
for(int32_t bs = 0; bs < B_; ++bs)
|
||||
for(int32_t k = 0; k < F_; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < NC_; ++c){
|
||||
for(int32_t c = 0; c < C_; ++c){
|
||||
int32_t h = p;
|
||||
int32_t w = q;
|
||||
if(h >= BH_/2 && h < AH_ - BH_/2
|
||||
@@ -97,11 +99,11 @@ public:
|
||||
h += shift_h_[c];
|
||||
w += shift_w_[c];
|
||||
}
|
||||
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_];
|
||||
IN_DTYPE b = F[k + c*NF_];
|
||||
IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_];
|
||||
IN_DTYPE b = F[k + c*F_];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc;
|
||||
O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,8 +111,8 @@ private:
|
||||
int32_t MAX_C_;
|
||||
int32_t TK_;
|
||||
// image size
|
||||
int32_t NB_;
|
||||
int32_t NC_;
|
||||
int32_t B_;
|
||||
int32_t C_;
|
||||
int32_t AD_;
|
||||
int32_t AH_;
|
||||
int32_t AW_;
|
||||
@@ -118,7 +120,7 @@ private:
|
||||
int32_t BD_;
|
||||
int32_t BH_;
|
||||
int32_t BW_;
|
||||
int32_t NF_;
|
||||
int32_t F_;
|
||||
// activation size
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
@@ -149,6 +151,9 @@ private:
|
||||
// convolution type
|
||||
type ty_;
|
||||
bool bias_;
|
||||
// transpose
|
||||
bool AT_;
|
||||
bool BT_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -160,13 +160,13 @@ private:
|
||||
|
||||
class indexing_expression: public postfix_expression{
|
||||
public:
|
||||
indexing_expression(node *id, node *slices)
|
||||
: id_((const identifier*)id), slices_((const list<slice*>*)slices) {}
|
||||
indexing_expression(node *lhs, node *slices)
|
||||
: lhs_((const expression*)lhs), slices_((const list<slice*>*)slices) {}
|
||||
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const identifier* id_;
|
||||
const expression* lhs_;
|
||||
const list<slice*>* slices_;
|
||||
};
|
||||
|
||||
|
@@ -157,7 +157,7 @@ slice_list
|
||||
|
||||
postfix_expression
|
||||
: primary_expression { $$ = $1;}
|
||||
| identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
|
||||
| primary_expression '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
|
||||
;
|
||||
|
||||
/* Unary */
|
||||
|
@@ -65,9 +65,9 @@ public:
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
ir::print(module, std::cout);
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
|
@@ -59,7 +59,7 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
||||
return;
|
||||
else{
|
||||
// std::cout << v->get_name() << std::endl;
|
||||
std::cout << v->get_name() << std::endl;
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
}
|
||||
// Reshape
|
||||
|
@@ -8,42 +8,63 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[3] = 1;
|
||||
ld[2] = shapes[3]*ld[3];
|
||||
ld[1] = shapes[2]*ld[2];
|
||||
ld[0] = shapes[1]*ld[1];
|
||||
ld[size - 1] = 1;
|
||||
for(int i = size - 1; i >= 1; i--)
|
||||
ld[i - 1] = shapes[i] * ld[i];
|
||||
}
|
||||
|
||||
shift::shift(int B, int NC,
|
||||
shift::shift(int B, int C,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S,
|
||||
int NF,
|
||||
int F,
|
||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
: NB_(B), NC_(NC),
|
||||
: B_(B), C_(C),
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
NF_(NF),
|
||||
F_(F),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias) {
|
||||
// max number of channels
|
||||
TK_ = 16;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
// transpose
|
||||
AT_ = false;
|
||||
BT_ = true;
|
||||
// equivalent matmul
|
||||
M_ = NB_*AH_*AW_;
|
||||
N_ = NF_;
|
||||
K_ = NC_;
|
||||
M_ = B_*AH_*AW_;
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// shapes
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
shapes_a_ = {NC, H, W, B};
|
||||
shapes_b_ = {NC, NF};
|
||||
shapes_c_ = {NF, H, W, B};
|
||||
// input layout: C, H, W, B
|
||||
// filter layout: C, F
|
||||
// output layout: F, H, W, B
|
||||
shapes_a_ = {C, H, W, B};
|
||||
shapes_b_ = {C, F};
|
||||
shapes_c_ = {F, H, W, B};
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
shapes_a_.swap(shapes_b_);
|
||||
AT_ = true;
|
||||
BT_ = false;
|
||||
M_ = K_;
|
||||
N_ = C_;
|
||||
K_ = B_*AH_*AW_;
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
shapes_a_.swap(shapes_c_);
|
||||
AT_ = false;
|
||||
BT_ = false;
|
||||
K_ = F_;
|
||||
M_ = B_*AH_*AW_;
|
||||
N_ = C_;
|
||||
}
|
||||
// memory strides
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
set_ld(shapes_b_, ld_b_);
|
||||
set_ld(shapes_c_, ld_c_);
|
||||
// build LUTs
|
||||
build_deltas();
|
||||
}
|
||||
@@ -57,7 +78,7 @@ void shift::build_deltas() {
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < NC_; c++)
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
}
|
||||
|
||||
@@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, NB_*AH_*AW_);
|
||||
kernel->setArg(7, NB_);
|
||||
kernel->setArg(6, B_*AH_*AW_);
|
||||
kernel->setArg(7, B_);
|
||||
kernel->setArg(8, AH_);
|
||||
kernel->setArg(9, AW_);
|
||||
kernel->setArg(10, BH_);
|
||||
kernel->setArg(11, BW_);
|
||||
// dry run
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::src(std::ostream &os) {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
@@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
int32 raw[TM] = rawhc % AW;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH;
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK];
|
||||
d = *pd;
|
||||
int32 offa1[TK] = rka*lda;
|
||||
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
|
||||
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
|
||||
)" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
)" << a_ty_ << R"( a[TM, TK] = *pa;
|
||||
)" << b_ty_ << R"( b[TN, TK] = *pb;
|
||||
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
|
||||
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
|
||||
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
|
||||
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
|
||||
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
|
||||
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
C = dot()" << usea << "," << useb << R"(, C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
pa = pa + (mask ? d[newaxis, :] : TK*lda);
|
||||
int1 checka[TM, TK] = k > TK;
|
||||
int1 checkb[TN, TK] = k > TK;
|
||||
inc_true = d)" << bca0 << R"(;
|
||||
inc_false = TK * lda;
|
||||
pa = pa + (mask ? inc_true : inc_false);
|
||||
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
|
||||
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
|
||||
@checka a = *pa;
|
||||
@checkb b = *pb;
|
||||
}
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
// std::cout << source << sd::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -175,7 +175,7 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||
|
||||
/* Postfix expression */
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
ir::value *in = mod->get_value(id_->name());
|
||||
ir::value *in = lhs_->codegen(mod);
|
||||
const std::vector<slice*> &slices = slices_->values();
|
||||
auto in_shapes = in->get_type()->get_tile_shapes();
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
@@ -234,44 +234,32 @@ ir::value* cast_expression::codegen(ir::module *mod) const{
|
||||
/* Conditional expression */
|
||||
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
|
||||
ir::value *pred = cond_->codegen(mod);
|
||||
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||
/* true value */
|
||||
ir::value *true_mask = mask->get_result(0);
|
||||
ir::value *false_mask = mask->get_result(1);
|
||||
auto it_true_begin = instructions.end();
|
||||
it_true_begin--;
|
||||
ir::value *true_value = true_value_->codegen(mod);
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
||||
itn->set_mask_pred(true_mask);
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
|
||||
itn->set_mask_pred(false_mask);
|
||||
bool is_float, is_ptr, is_int, is_signed;
|
||||
ir::value *uncasted_true_value = true_value;
|
||||
ir::value *uncasted_false_value = false_value;
|
||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
implicit_broadcast(mod, pred, true_value);
|
||||
it_true_begin++;
|
||||
auto it_true_end = instructions.end();
|
||||
for(auto it = it_true_begin; it != it_true_end; it++)
|
||||
(*it)->set_mask_pred(true_mask);
|
||||
/* false value */
|
||||
ir::value *false_mask = mask->get_result(1);
|
||||
auto it_false_begin = instructions.end();
|
||||
it_false_begin--;
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
it_false_begin++;
|
||||
implicit_broadcast(mod, pred, false_value);
|
||||
{
|
||||
ir::value *current = true_value;
|
||||
while(current != uncasted_true_value) {
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||
itn->set_mask_pred(true_mask);
|
||||
current = itn->get_operand(0);
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
{
|
||||
ir::value *current = false_value;
|
||||
while(current != uncasted_false_value) {
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||
itn->set_mask_pred(false_mask);
|
||||
current = itn->get_operand(0);
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto it_false_end = instructions.end();
|
||||
for(auto it = it_false_begin; it != it_false_end; it++)
|
||||
(*it)->set_mask_pred(false_mask);
|
||||
/* cast */
|
||||
bool is_float, is_ptr, is_int, is_signed;
|
||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||
return result;
|
||||
}
|
||||
|
Reference in New Issue
Block a user