preparing the field for tensor cores transposes
This commit is contained in:
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = false;
|
bool AT = false;
|
||||||
bool BT = true;
|
bool BT = false;
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
|
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
|
|||||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
|
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
|
||||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
||||||
// template
|
// template
|
||||||
triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
|
triton::dnn::gemm dot(M, N, K, false, false, "fp16", "fp16", 4, 4);
|
||||||
dot.enqueue(stream, {&da, &db, &dc});
|
dot.enqueue(stream, {&da, &db, &dc});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -23,7 +23,7 @@ def run_dot():
|
|||||||
result = sess.run([c], feed_dict = {a: ha,
|
result = sess.run([c], feed_dict = {a: ha,
|
||||||
b: hb})[0]
|
b: hb})[0]
|
||||||
# Test
|
# Test
|
||||||
hresult = np.dot(ha.T, hb).T
|
hresult = np.dot(ha.T, hb.T).T
|
||||||
dif = np.abs(result - hresult)
|
dif = np.abs(result - hresult)
|
||||||
print(hresult)
|
print(hresult)
|
||||||
print(result)
|
print(result)
|
||||||
|
@@ -550,6 +550,7 @@ private:
|
|||||||
std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); }
|
std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
|
||||||
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
@@ -68,9 +68,9 @@ public:
|
|||||||
target_(target) { }
|
target_(target) { }
|
||||||
|
|
||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
// ir::print(module, std::cout);
|
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
void target_dependent(ir::module &module) {
|
void target_dependent(ir::module &module) {
|
||||||
|
@@ -11,6 +11,21 @@ inline bool is_trans(ir::value *v){
|
|||||||
return dynamic_cast<ir::trans_inst*>(v) != nullptr;
|
return dynamic_cast<ir::trans_inst*>(v) != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_hmma(ir::value *v){
|
||||||
|
bool result = false;
|
||||||
|
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||||
|
ir::value *a = x->get_operand(0);
|
||||||
|
ir::type *a_ty = a->get_type();
|
||||||
|
ir::value *b = x->get_operand(1);
|
||||||
|
ir::type *b_ty = b->get_type();
|
||||||
|
// inputs have to be FP16
|
||||||
|
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
||||||
|
// reduction has to be multiple of 4
|
||||||
|
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
void optimize_dot::run(ir::module &mod) {
|
void optimize_dot::run(ir::module &mod) {
|
||||||
ir::builder &builder = mod.get_builder();
|
ir::builder &builder = mod.get_builder();
|
||||||
std::vector<ir::instruction*> to_delete;
|
std::vector<ir::instruction*> to_delete;
|
||||||
@@ -19,28 +34,49 @@ void optimize_dot::run(ir::module &mod) {
|
|||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *i: block->get_inst_list())
|
for(ir::instruction *i: block->get_inst_list())
|
||||||
if(auto dot = dynamic_cast<ir::dot_inst*>(i))
|
if(auto dot = dynamic_cast<ir::dot_inst*>(i))
|
||||||
if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1)
|
if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
|
||||||
if(!dot->is_a_trans() && !dot->is_b_trans()){
|
|
||||||
builder.set_insert_point(i);
|
builder.set_insert_point(i);
|
||||||
ir::value *A = dot->get_operand(0);
|
ir::value *A = dot->get_operand(0);
|
||||||
ir::value *B = dot->get_operand(1);
|
ir::value *B = dot->get_operand(1);
|
||||||
ir::value *D = dot->get_operand(2);
|
ir::value *D = dot->get_operand(2);
|
||||||
|
bool trans_a = is_trans(A);
|
||||||
|
bool trans_b = is_trans(B);
|
||||||
|
|
||||||
|
if(!dot->is_a_trans() && !dot->is_b_trans()){
|
||||||
|
if(is_hmma(dot)){
|
||||||
|
ir::value *AA = A;
|
||||||
|
ir::value *BB = B;
|
||||||
|
if(trans_a){
|
||||||
|
AA = ((ir::trans_inst*)A)->get_operand(0);
|
||||||
|
to_delete.push_back((ir::instruction*)A);
|
||||||
|
}
|
||||||
|
if(trans_b){
|
||||||
|
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||||
|
to_delete.push_back((ir::instruction*)B);
|
||||||
|
}
|
||||||
|
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||||
|
dot->replace_all_uses_with(dot_atbt);
|
||||||
|
to_delete.push_back(dot);
|
||||||
|
}
|
||||||
|
else{
|
||||||
// dot(op(a), trans(b))
|
// dot(op(a), trans(b))
|
||||||
if(is_trans(B)){
|
if(trans_b){
|
||||||
ir::value* BN = ((ir::trans_inst*)B)->get_operand(0);
|
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BN, D));
|
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||||
dot->replace_all_uses_with(NT);
|
dot->replace_all_uses_with(NT);
|
||||||
to_delete.push_back((ir::instruction*)B);
|
to_delete.push_back((ir::instruction*)B);
|
||||||
to_delete.push_back(dot);
|
to_delete.push_back(dot);
|
||||||
}
|
}
|
||||||
// dot(op(a), b)
|
// dot(op(a), b)
|
||||||
if(!is_trans(B)){
|
if(!trans_b){
|
||||||
ir::value* BT = builder.create_trans(B);
|
ir::value* BB = builder.create_trans(B);
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BT, D));
|
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||||
dot->replace_all_uses_with(NT);
|
dot->replace_all_uses_with(NT);
|
||||||
to_delete.push_back(dot);
|
to_delete.push_back(dot);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for(ir::instruction* i: to_delete)
|
for(ir::instruction* i: to_delete)
|
||||||
i->erase_from_parent();
|
i->erase_from_parent();
|
||||||
|
@@ -22,10 +22,8 @@ bool is_hmma(ir::value *v){
|
|||||||
ir::type *a_ty = a->get_type();
|
ir::type *a_ty = a->get_type();
|
||||||
ir::value *b = x->get_operand(1);
|
ir::value *b = x->get_operand(1);
|
||||||
ir::type *b_ty = b->get_type();
|
ir::type *b_ty = b->get_type();
|
||||||
// only NT supported
|
|
||||||
result = !x->is_a_trans() && x->is_b_trans();
|
|
||||||
// inputs have to be FP16
|
// inputs have to be FP16
|
||||||
result = result && a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
||||||
// reduction has to be multiple of 4
|
// reduction has to be multiple of 4
|
||||||
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
||||||
}
|
}
|
||||||
@@ -223,7 +221,7 @@ void tune::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
|
||||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -527,6 +527,14 @@ dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
|||||||
set_operand(2, C);
|
set_operand(2, C);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
instruction *dot_inst::create(value *A, value *B, value *C,
|
||||||
|
bool AT, bool BT,
|
||||||
|
const std::string &name, instruction *next) {
|
||||||
|
TransT OPA = AT ? Trans : NoTrans;
|
||||||
|
TransT OPB = BT ? Trans : NoTrans;
|
||||||
|
return new dot_inst(A, B, C, OPA, OPB, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create_nn(value *A, value *B, value *C,
|
instruction *dot_inst::create_nn(value *A, value *B, value *C,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
|
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
|
||||||
|
@@ -135,13 +135,8 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
|||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
for(ir::metaparameter *mp: mps)
|
for(ir::metaparameter *mp: mps)
|
||||||
mp->set_value(params[i++]);
|
mp->set_value(params[i++]);
|
||||||
passes.target_independent(tt_module);
|
|
||||||
passes.tune.init(tt_module);
|
passes.tune.init(tt_module);
|
||||||
passes.tune.check_constraints(errors);
|
passes.tune.check_constraints(errors);
|
||||||
// for(auto e: errors)
|
|
||||||
// for(auto x: e.second)
|
|
||||||
// std::cout << x << std::endl;
|
|
||||||
// std::cout << "-----" << std::endl;
|
|
||||||
if(!errors.empty())
|
if(!errors.empty())
|
||||||
return;
|
return;
|
||||||
result = params;
|
result = params;
|
||||||
|
Reference in New Issue
Block a user