[codegen] shift: added sketch for shift-convolution backpropagation
This commit is contained in:
@@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = false;
|
bool AT = true;
|
||||||
bool BT = true;
|
bool BT = true;
|
||||||
|
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
@@ -16,7 +16,7 @@ int main() {
|
|||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
|
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 32768, N = 1024, K = 1024;
|
int32_t M = 1024, N = 1024, K = 1024;
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
@@ -59,9 +59,9 @@ int main() {
|
|||||||
|
|
||||||
|
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 1, 1);
|
std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 4, 4);
|
||||||
jit.autotune("matmul",src.c_str(), benchmark);
|
// jit.autotune("matmul",src.c_str(), benchmark);
|
||||||
jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT));
|
jit.add_module("matmul", src.c_str(), {8, 16, 4, 2, 16, 8, 4, 2, 2, 4, 2, 8, 8, 1});
|
||||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||||
|
@@ -19,8 +19,8 @@ int main() {
|
|||||||
|
|
||||||
// initialization
|
// initialization
|
||||||
int32_t R = 3, S = 3;
|
int32_t R = 3, S = 3;
|
||||||
int32_t BS = 32, F = 1024;
|
int32_t BS = 4, F = 1024;
|
||||||
int32_t H = 32, W = 32;
|
int32_t H = 16, W = 16;
|
||||||
int32_t C = 1024;
|
int32_t C = 1024;
|
||||||
|
|
||||||
// random shifts
|
// random shifts
|
||||||
@@ -31,7 +31,7 @@ int main() {
|
|||||||
shift_w[c] = rand() % S - S/2;
|
shift_w[c] = rand() % S - S/2;
|
||||||
}
|
}
|
||||||
// configuration
|
// configuration
|
||||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
|
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||||
// host buffers
|
// host buffers
|
||||||
std::vector<float> hc(shift.c_size());
|
std::vector<float> hc(shift.c_size());
|
||||||
std::vector<float> rc(shift.c_size());
|
std::vector<float> rc(shift.c_size());
|
||||||
|
@@ -128,6 +128,7 @@ public:
|
|||||||
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
|
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
|
||||||
value *create_get_range_id(unsigned axis, const std::string &name = "");
|
value *create_get_range_id(unsigned axis, const std::string &name = "");
|
||||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||||
|
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::string &name = "");
|
value *create_trans(value *A, const std::string &name = "");
|
||||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||||
|
@@ -532,6 +532,15 @@ public:
|
|||||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class atomic_add_inst: public builtin_inst {
|
||||||
|
private:
|
||||||
|
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
std::string repr_impl() const { return "atomic_add"; }
|
||||||
|
|
||||||
|
public:
|
||||||
|
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
class dot_inst: public builtin_inst {
|
class dot_inst: public builtin_inst {
|
||||||
public:
|
public:
|
||||||
enum TransT { NoTrans, Trans };
|
enum TransT { NoTrans, Trans };
|
||||||
|
@@ -101,6 +101,16 @@ private:
|
|||||||
const node *val_;
|
const node *val_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class atomic_add_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
atomic_add_expression(node *ptr, node *val): ptr_(ptr), val_(val) { }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const node *ptr_;
|
||||||
|
const node *val_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class matmul_expression: public builtin_expression{
|
class matmul_expression: public builtin_expression{
|
||||||
public:
|
public:
|
||||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
|||||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||||
%token IF ELSE FOR CONTINUE WHILE
|
%token IF ELSE FOR CONTINUE WHILE
|
||||||
%token NEWAXIS ELLIPSIS AT
|
%token NEWAXIS ELLIPSIS AT
|
||||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ALLOC_CONST
|
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||||
|
|
||||||
%start translation_unit
|
%start translation_unit
|
||||||
%%
|
%%
|
||||||
@@ -129,6 +129,7 @@ builtin_expression
|
|||||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||||
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
||||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||||
|
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Primary */
|
/* Primary */
|
||||||
|
@@ -28,26 +28,27 @@ using triton::lang::return_void;
|
|||||||
"if" { return return_impl(IF, yytext); }
|
"if" { return return_impl(IF, yytext); }
|
||||||
"else" { return return_impl(ELSE, yytext); }
|
"else" { return return_impl(ELSE, yytext); }
|
||||||
"for" { return return_impl(FOR, yytext); }
|
"for" { return return_impl(FOR, yytext); }
|
||||||
"while" { return return_impl(WHILE, yytext); }
|
"while" { return return_impl(WHILE, yytext); }
|
||||||
"void" { return return_impl(VOID, yytext); }
|
"void" { return return_impl(VOID, yytext); }
|
||||||
"uint1" { return return_impl(UINT1, yytext); }
|
"uint1" { return return_impl(UINT1, yytext); }
|
||||||
"uint8" { return return_impl(UINT8, yytext); }
|
"uint8" { return return_impl(UINT8, yytext); }
|
||||||
"uint16" { return return_impl(UINT16, yytext); }
|
"uint16" { return return_impl(UINT16, yytext); }
|
||||||
"uint32" { return return_impl(UINT32, yytext); }
|
"uint32" { return return_impl(UINT32, yytext); }
|
||||||
"uint64" { return return_impl(UINT64, yytext); }
|
"uint64" { return return_impl(UINT64, yytext); }
|
||||||
"int1" { return return_impl(INT1, yytext); }
|
"int1" { return return_impl(INT1, yytext); }
|
||||||
"int8" { return return_impl(INT8, yytext); }
|
"int8" { return return_impl(INT8, yytext); }
|
||||||
"int16" { return return_impl(INT16, yytext); }
|
"int16" { return return_impl(INT16, yytext); }
|
||||||
"int32" { return return_impl(INT32, yytext); }
|
"int32" { return return_impl(INT32, yytext); }
|
||||||
"int64" { return return_impl(INT64, yytext); }
|
"int64" { return return_impl(INT64, yytext); }
|
||||||
"fp16" { return return_impl(FP16, yytext); }
|
"fp16" { return return_impl(FP16, yytext); }
|
||||||
"fp32" { return return_impl(FP32, yytext); }
|
"fp32" { return return_impl(FP32, yytext); }
|
||||||
"fp64" { return return_impl(FP64, yytext); }
|
"fp64" { return return_impl(FP64, yytext); }
|
||||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||||
"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); }
|
"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); }
|
||||||
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
||||||
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
||||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||||
|
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||||
"dot" { return return_impl(DOT, yytext); }
|
"dot" { return return_impl(DOT, yytext); }
|
||||||
"max" { return return_impl(MAX, yytext); }
|
"max" { return return_impl(MAX, yytext); }
|
||||||
"min" { return return_impl(MIN, yytext); }
|
"min" { return return_impl(MIN, yytext); }
|
||||||
@@ -58,57 +59,57 @@ using triton::lang::return_void;
|
|||||||
{L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); }
|
{L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); }
|
||||||
0[xX]{H}+{IS}? { return return_impl(CONSTANT, yytext); }
|
0[xX]{H}+{IS}? { return return_impl(CONSTANT, yytext); }
|
||||||
0{D}+{IS}? { return return_impl(CONSTANT, yytext); }
|
0{D}+{IS}? { return return_impl(CONSTANT, yytext); }
|
||||||
{D}+{IS}? { return return_impl(CONSTANT, yytext); }
|
{D}+{IS}? { return return_impl(CONSTANT, yytext); }
|
||||||
L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); }
|
L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); }
|
||||||
{D}+{E}{FS}? { return return_impl(CONSTANT, yytext); }
|
{D}+{E}{FS}? { return return_impl(CONSTANT, yytext); }
|
||||||
L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); }
|
L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); }
|
||||||
">>=" { return return_impl(RIGHT_ASSIGN, yytext); }
|
">>=" { return return_impl(RIGHT_ASSIGN, yytext); }
|
||||||
"<<=" { return return_impl(LEFT_ASSIGN, yytext); }
|
"<<=" { return return_impl(LEFT_ASSIGN, yytext); }
|
||||||
"+=" { return return_impl(ADD_ASSIGN, yytext); }
|
"+=" { return return_impl(ADD_ASSIGN, yytext); }
|
||||||
"-=" { return return_impl(SUB_ASSIGN, yytext); }
|
"-=" { return return_impl(SUB_ASSIGN, yytext); }
|
||||||
"*=" { return return_impl(MUL_ASSIGN, yytext); }
|
"*=" { return return_impl(MUL_ASSIGN, yytext); }
|
||||||
"/=" { return return_impl(DIV_ASSIGN, yytext); }
|
"/=" { return return_impl(DIV_ASSIGN, yytext); }
|
||||||
"%=" { return return_impl(MOD_ASSIGN, yytext); }
|
"%=" { return return_impl(MOD_ASSIGN, yytext); }
|
||||||
"&=" { return return_impl(AND_ASSIGN, yytext); }
|
"&=" { return return_impl(AND_ASSIGN, yytext); }
|
||||||
"^=" { return return_impl(XOR_ASSIGN, yytext); }
|
"^=" { return return_impl(XOR_ASSIGN, yytext); }
|
||||||
"|=" { return return_impl(OR_ASSIGN, yytext); }
|
"|=" { return return_impl(OR_ASSIGN, yytext); }
|
||||||
">>" { return return_impl(RIGHT_OP, yytext); }
|
">>" { return return_impl(RIGHT_OP, yytext); }
|
||||||
"<<" { return return_impl(LEFT_OP, yytext); }
|
"<<" { return return_impl(LEFT_OP, yytext); }
|
||||||
"++" { return return_impl(INC_OP, yytext); }
|
"++" { return return_impl(INC_OP, yytext); }
|
||||||
"--" { return return_impl(DEC_OP, yytext); }
|
"--" { return return_impl(DEC_OP, yytext); }
|
||||||
"->" { return return_impl(PTR_OP, yytext); }
|
"->" { return return_impl(PTR_OP, yytext); }
|
||||||
"&&" { return return_impl(AND_OP, yytext); }
|
"&&" { return return_impl(AND_OP, yytext); }
|
||||||
"||" { return return_impl(OR_OP, yytext); }
|
"||" { return return_impl(OR_OP, yytext); }
|
||||||
"<=" { return return_impl(LE_OP, yytext); }
|
"<=" { return return_impl(LE_OP, yytext); }
|
||||||
">=" { return return_impl(GE_OP, yytext); }
|
">=" { return return_impl(GE_OP, yytext); }
|
||||||
"==" { return return_impl(EQ_OP, yytext); }
|
"==" { return return_impl(EQ_OP, yytext); }
|
||||||
"!=" { return return_impl(NE_OP, yytext); }
|
"!=" { return return_impl(NE_OP, yytext); }
|
||||||
";" { return return_impl(';', yytext); }
|
";" { return return_impl(';', yytext); }
|
||||||
("{"|"<%") { return return_impl('{', yytext); }
|
("{"|"<%") { return return_impl('{', yytext); }
|
||||||
("}"|"%>") { return return_impl('}', yytext); }
|
("}"|"%>") { return return_impl('}', yytext); }
|
||||||
"," { return return_impl(',', yytext); }
|
"," { return return_impl(',', yytext); }
|
||||||
":" { return return_impl(':', yytext); }
|
":" { return return_impl(':', yytext); }
|
||||||
"=" { return return_impl('=', yytext); }
|
"=" { return return_impl('=', yytext); }
|
||||||
"(" { return return_impl('(', yytext); }
|
"(" { return return_impl('(', yytext); }
|
||||||
")" { return return_impl(')', yytext); }
|
")" { return return_impl(')', yytext); }
|
||||||
("["|"<:") { return return_impl('[', yytext); }
|
("["|"<:") { return return_impl('[', yytext); }
|
||||||
("]"|":>") { return return_impl(']', yytext); }
|
("]"|":>") { return return_impl(']', yytext); }
|
||||||
"." { return return_impl('.', yytext); }
|
"." { return return_impl('.', yytext); }
|
||||||
"&" { return return_impl('&', yytext); }
|
"&" { return return_impl('&', yytext); }
|
||||||
"!" { return return_impl('!', yytext); }
|
"!" { return return_impl('!', yytext); }
|
||||||
"~" { return return_impl('~', yytext); }
|
"~" { return return_impl('~', yytext); }
|
||||||
"-" { return return_impl('-', yytext); }
|
"-" { return return_impl('-', yytext); }
|
||||||
"+" { return return_impl('+', yytext); }
|
"+" { return return_impl('+', yytext); }
|
||||||
"*" { return return_impl('*', yytext); }
|
"*" { return return_impl('*', yytext); }
|
||||||
"/" { return return_impl('/', yytext); }
|
"/" { return return_impl('/', yytext); }
|
||||||
"%" { return return_impl('%', yytext); }
|
"%" { return return_impl('%', yytext); }
|
||||||
"<" { return return_impl('<', yytext); }
|
"<" { return return_impl('<', yytext); }
|
||||||
">" { return return_impl('>', yytext); }
|
">" { return return_impl('>', yytext); }
|
||||||
"^" { return return_impl('^', yytext); }
|
"^" { return return_impl('^', yytext); }
|
||||||
"|" { return return_impl('|', yytext); }
|
"|" { return return_impl('|', yytext); }
|
||||||
"?" { return return_impl('?', yytext); }
|
"?" { return return_impl('?', yytext); }
|
||||||
[ \t\v\n\f] { return_void(yytext);}
|
[ \t\v\n\f] { return_void(yytext);}
|
||||||
. { /* ignore bad characters */ }
|
. { /* ignore bad characters */ }
|
||||||
|
|
||||||
%%
|
%%
|
||||||
|
|
||||||
|
@@ -373,6 +373,13 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
|||||||
Value *res = builder.CreateLoad(ptr);
|
Value *res = builder.CreateLoad(ptr);
|
||||||
return (Instruction*)res;
|
return (Instruction*)res;
|
||||||
}
|
}
|
||||||
|
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
|
||||||
|
Value *ptr = value(ii->get_operand(0));
|
||||||
|
Value *val = value(ii->get_operand(1));
|
||||||
|
Value *atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
|
||||||
|
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
||||||
|
return (Instruction*)res;
|
||||||
|
}
|
||||||
// unknown instruction
|
// unknown instruction
|
||||||
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
||||||
}
|
}
|
||||||
|
@@ -56,6 +56,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
|||||||
ir::type::tile_shapes_t shapes;
|
ir::type::tile_shapes_t shapes;
|
||||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||||
|
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
|
||||||
|
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||||
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
||||||
return;
|
return;
|
||||||
else
|
else
|
||||||
@@ -233,13 +235,13 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||||
*params_.at(i).at("nts.d0") = *tmp1;
|
*params_.at(i).at("nts.d0") = *tmp1;
|
||||||
*params_.at(i).at("nts.d1") = *tmp2;
|
*params_.at(i).at("nts.d1") = *tmp2;
|
||||||
}
|
}
|
||||||
|
@@ -114,6 +114,8 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
|
|||||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||||
size_t TM, size_t TN, size_t nthreads) {
|
size_t TM, size_t TN, size_t nthreads) {
|
||||||
|
if(ty_ == WGRAD)
|
||||||
|
std::swap(a, b);
|
||||||
kernel->setArg(0, a);
|
kernel->setArg(0, a);
|
||||||
kernel->setArg(1, b);
|
kernel->setArg(1, b);
|
||||||
kernel->setArg(2, c);
|
kernel->setArg(2, c);
|
||||||
@@ -121,24 +123,35 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
|||||||
kernel->setArg(4, N_);
|
kernel->setArg(4, N_);
|
||||||
kernel->setArg(5, K_);
|
kernel->setArg(5, K_);
|
||||||
kernel->setArg(6, B_*AH_*AW_);
|
kernel->setArg(6, B_*AH_*AW_);
|
||||||
kernel->setArg(7, B_);
|
kernel->setArg(7, N_);
|
||||||
kernel->setArg(8, AH_);
|
kernel->setArg(8, B_);
|
||||||
kernel->setArg(9, AW_);
|
kernel->setArg(9, AH_);
|
||||||
kernel->setArg(10, BH_);
|
kernel->setArg(10, AW_);
|
||||||
kernel->setArg(11, BW_);
|
kernel->setArg(11, BH_);
|
||||||
|
kernel->setArg(12, BW_);
|
||||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||||
|
if(ty_ == BPROP)
|
||||||
|
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
|
||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
void shift::src(std::ostream &os) {
|
void shift::src(std::ostream &os) {
|
||||||
std::string AS0 = "TM", AS1 = "TK";
|
std::string AS0 = "TM", AS1 = "TK";
|
||||||
std::string BS0 = "TK", BS1 = "TN";
|
std::string BS0 = "TK", BS1 = "TN";
|
||||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
|
||||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||||
std::string lda0 = "*lda", lda1 = "";
|
|
||||||
std::string ldb0 = "", ldb1 = "*ldb";
|
std::string ldb0 = "", ldb1 = "*ldb";
|
||||||
std::string usea = AT_ ? "trans(a)" : "a";
|
std::string usea = AT_ ? "trans(a)" : "a";
|
||||||
std::string useb = BT_ ? "trans(b)" : "b";
|
std::string useb = BT_ ? "trans(b)" : "b";
|
||||||
|
std::string rkb = "rkb";
|
||||||
|
std::string rka = "rka";
|
||||||
|
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||||
|
std::string lda0 = "*lda", lda1 = "";
|
||||||
|
if(ty_ == FPROP){
|
||||||
|
rka = "inc";
|
||||||
|
bca0 = "";
|
||||||
|
lda0 = "";
|
||||||
|
}
|
||||||
|
|
||||||
if(AT_){
|
if(AT_){
|
||||||
std::swap(AS0, AS1);
|
std::swap(AS0, AS1);
|
||||||
std::swap(bca0, bca1);
|
std::swap(bca0, bca1);
|
||||||
@@ -149,6 +162,8 @@ void shift::src(std::ostream &os) {
|
|||||||
std::swap(bcb0, bcb1);
|
std::swap(bcb0, bcb1);
|
||||||
std::swap(ldb0, ldb1);
|
std::swap(ldb0, ldb1);
|
||||||
}
|
}
|
||||||
|
std::string AS = AS0 + ", " + AS1;
|
||||||
|
std::string BS = BS0 + ", " + BS1;
|
||||||
|
|
||||||
os <<
|
os <<
|
||||||
R"(
|
R"(
|
||||||
@@ -161,8 +176,8 @@ __constant__ int32* delta = alloc_const int32[)" << MAX_C_ << R"(];
|
|||||||
void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||||
fp32 *c,
|
fp32 *c,
|
||||||
multiple_of(4) int32 M, multiple_of(4) int32 N, multiple_of(4) int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
multiple_of(4) int32 lda,
|
multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
|
||||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||||
int32 rxa[TM] = get_global_range[TM](0);
|
int32 rxa[TM] = get_global_range[TM](0);
|
||||||
int32 ryb[TN] = get_global_range[TN](1);
|
int32 ryb[TN] = get_global_range[TN](1);
|
||||||
@@ -170,7 +185,9 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
|||||||
int32 rkb[TK] = 0 ... TK;
|
int32 rkb[TK] = 0 ... TK;
|
||||||
fp32 C[TM, TN] = 0;
|
fp32 C[TM, TN] = 0;
|
||||||
int32 pad_h = AR / 2;
|
int32 pad_h = AR / 2;
|
||||||
int32 pad_w = AS / 2;
|
int32 pad_w = AS / 2;)";
|
||||||
|
if(ty_ == FPROP){
|
||||||
|
os << R"(
|
||||||
int32 rawhc[TM] = rxa / ABS;
|
int32 rawhc[TM] = rxa / ABS;
|
||||||
int32 raw[TM] = rawhc % AW;
|
int32 raw[TM] = rawhc % AW;
|
||||||
int32 rahc[TM] = rawhc / AW;
|
int32 rahc[TM] = rawhc / AW;
|
||||||
@@ -179,35 +196,86 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
|||||||
multiple_of(4) int32 d[TK] = *pd;
|
multiple_of(4) int32 d[TK] = *pd;
|
||||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||||
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
|
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||||
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
|
int32 inc_true[TM, TK] = d[newaxis, :];
|
||||||
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
|
int32 inc_false[TM, TK] = rka[newaxis, :] * lda;
|
||||||
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
|
int32 inc[TM, TK] = mask ? inc_true : inc_false;)";
|
||||||
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
|
}
|
||||||
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
|
if(ty_ == WGRAD){
|
||||||
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
|
os << R"(
|
||||||
|
int32 shift[TK, TN] = 0;)";
|
||||||
|
}
|
||||||
|
os << R"(
|
||||||
|
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
|
||||||
|
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||||
|
)" << a_ty_ << " a[" << AS << R"(] = *pa;
|
||||||
|
)" << b_ty_ << " b[" << BS << R"(] = *pb;
|
||||||
for(int32 k = K; k > 0; k = k - TK){
|
for(int32 k = K; k > 0; k = k - TK){
|
||||||
C = dot()" << usea << "," << useb << R"(, C);
|
C = dot()" << usea << "," << useb << R"(, C);
|
||||||
pb = pb + TK*N;
|
int1 checka[)" << AS << R"(] = k > TK;
|
||||||
|
int1 checkb[)" << BS << R"(] = k > TK;)";
|
||||||
|
if(ty_ == FPROP){
|
||||||
|
os << R"(
|
||||||
pd = pd + TK;
|
pd = pd + TK;
|
||||||
d = *pd;
|
d = *pd;
|
||||||
inc_true = d)" << bca0 << R"(;
|
inc_true = d[newaxis, :];
|
||||||
inc_false = TK * lda;
|
inc_false = TK * lda;
|
||||||
pa = pa + (mask ? inc_true : inc_false);
|
inc = mask ? inc_true : inc_false;
|
||||||
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
|
pa = pa + inc;
|
||||||
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
|
@checka a = *pa;)";
|
||||||
@checka a = *pa;
|
}
|
||||||
@checkb b = *pb;
|
else{
|
||||||
|
os << R"(
|
||||||
|
pa = pa + TK)" << lda0 << R"(;
|
||||||
|
@checka a = *pa;)";
|
||||||
|
}
|
||||||
|
if(ty_ == WGRAD){
|
||||||
|
os << R"(
|
||||||
|
int32 rbwhc[TK] = rkb / ABS;
|
||||||
|
int32 rbw[TK] = rbwhc % AW;
|
||||||
|
int32 rbhc[TK] = rbwhc / AW;
|
||||||
|
int32 rbh[TK] = rbhc % AH;
|
||||||
|
int1 maskh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||||
|
int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||||
|
int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||||
|
int32 inc[TK, TN] = mask ? 0 : shift;
|
||||||
|
pb = pb + TK;
|
||||||
|
)" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
|
||||||
|
@checkb b = *pbb;)";
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
os << R"(
|
||||||
|
pb = pb + TK)" << ldb0 << R"(;
|
||||||
|
@checkb b = *pb;)";
|
||||||
|
}
|
||||||
|
os << R"(
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = get_global_range[TM](0);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);
|
int32 ryc[TN] = get_global_range[TN](1);
|
||||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||||
int1 checkc0[TM] = rxc < M;
|
int1 checkc0[TM] = rxc < M;
|
||||||
int1 checkc1[TN] = ryc < N;
|
int1 checkc1[TN] = ryc < N;
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||||
@checkc *pc = C;
|
if(ty_ == BPROP){
|
||||||
|
os << R"(
|
||||||
|
int32 rcwhc[TM] = rxc / ABS;
|
||||||
|
int32 rcw[TM] = rcwhc % AW;
|
||||||
|
int32 rchc[TM] = rcwhc / AW;
|
||||||
|
int32 rch[TM] = rchc % AH;
|
||||||
|
int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
||||||
|
int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
||||||
|
int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||||
|
fp32* shiftpc[TM, TN] = pc + 0;
|
||||||
|
pc = interior ? shiftpc : pc;
|
||||||
|
@checkc __atomic_add(pc, C);
|
||||||
|
)";
|
||||||
}
|
}
|
||||||
)";
|
else{
|
||||||
|
os << R"(
|
||||||
|
@checkc *pc = C;)";
|
||||||
|
}
|
||||||
|
os << R"(
|
||||||
|
})";
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -308,6 +308,10 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std:
|
|||||||
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
|
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
|
||||||
|
return insert(atomic_add_inst::create(ptr, val, name));
|
||||||
|
}
|
||||||
|
|
||||||
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
||||||
return insert(dot_inst::create_nn(A, B, C, name));
|
return insert(dot_inst::create_nn(A, B, C, name));
|
||||||
}
|
}
|
||||||
|
@@ -620,6 +620,19 @@ atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::
|
|||||||
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
|
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
|
||||||
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// atomic add
|
||||||
|
|
||||||
|
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||||
|
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
|
||||||
|
set_operand(0, ptr);
|
||||||
|
set_operand(1, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
|
||||||
|
return new atomic_add_inst(ptr, val, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// intrinsic instructions
|
// intrinsic instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -59,8 +59,12 @@ unsigned type::get_pointer_address_space() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type * type::get_pointer_element_ty() const {
|
type * type::get_pointer_element_ty() const {
|
||||||
assert(is_pointer_ty());
|
type *ptr_ty = get_scalar_ty();
|
||||||
return ((pointer_type*)this)->get_element_ty();
|
assert(ptr_ty->is_pointer_ty());
|
||||||
|
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
|
||||||
|
if(is_tile_ty())
|
||||||
|
return tile_type::get_same_shapes(scalar_ty, (type*)this);
|
||||||
|
return scalar_ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -130,6 +130,13 @@ ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
|
|||||||
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// atomic add
|
||||||
|
ir::value* atomic_add_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value *ptr = ptr_->codegen(mod);
|
||||||
|
ir::value *val = val_->codegen(mod);
|
||||||
|
return mod->get_builder().create_atomic_add(ptr, val);
|
||||||
|
}
|
||||||
|
|
||||||
// matmul
|
// matmul
|
||||||
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||||
ir::value *A = A_->codegen(mod);
|
ir::value *A = A_->codegen(mod);
|
||||||
|
@@ -53,8 +53,6 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
|||||||
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
|
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
|
||||||
// One operand is pointer
|
// One operand is pointer
|
||||||
if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){
|
if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){
|
||||||
if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty())
|
|
||||||
throw std::runtime_error("invalid operands");
|
|
||||||
if(right_ty->is_pointer_ty())
|
if(right_ty->is_pointer_ty())
|
||||||
std::swap(lhs, rhs);
|
std::swap(lhs, rhs);
|
||||||
is_ptr = true;
|
is_ptr = true;
|
||||||
|
@@ -31,9 +31,6 @@ ir::value* expression_statement::codegen(ir::module *mod) const{
|
|||||||
ir::builder &builder = mod->get_builder();
|
ir::builder &builder = mod->get_builder();
|
||||||
ir::basic_block *block = builder.get_insert_block();
|
ir::basic_block *block = builder.get_insert_block();
|
||||||
if(pred_) {
|
if(pred_) {
|
||||||
// check that it is an assignment
|
|
||||||
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
|
|
||||||
assert(assignment);
|
|
||||||
// generate mask
|
// generate mask
|
||||||
ir::value *pred = pred_->codegen(mod);
|
ir::value *pred = pred_->codegen(mod);
|
||||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||||
@@ -53,8 +50,10 @@ ir::value* expression_statement::codegen(ir::module *mod) const{
|
|||||||
// merge with psi
|
// merge with psi
|
||||||
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
||||||
mask->get_result(1), ir::undef_value::get(ty));
|
mask->get_result(1), ir::undef_value::get(ty));
|
||||||
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
if(assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_)){
|
||||||
mod->set_value(name, psi);
|
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
||||||
|
mod->set_value(name, psi);
|
||||||
|
}
|
||||||
return psi;
|
return psi;
|
||||||
}
|
}
|
||||||
return expr_->codegen(mod);
|
return expr_->codegen(mod);
|
||||||
|
Reference in New Issue
Block a user