[PYTHON][KERNEL] Added benchmarking functionalities for kernels
@@ -56,14 +56,12 @@ void disassociate::run(ir::module &mod) {
       bld.set_insert_point(y);
       bld.insert(cloned);
       clone_map[y] = cloned;
-      // replace in above level
-      if(depth > 1){
+      // replace operands of parents
+      if(depth > 1)
         for(ir::user* ux: x.second.at(depth - 1))
           clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
-      }
-      else{
+      else
         x.first->replace_uses_of_with(y, cloned);
-      }
     }
     depth += 1;
   }
@@ -2,11 +2,11 @@ import numpy as np
 import triton

 def run_tf():
-    M, N, K = 128, 128, 128
+    M, N, K = 2048, 2048, 2048
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
-    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False)
+    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
+    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
     tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
     tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
     # Gradient
@@ -20,15 +20,20 @@ def run_tf():
     sess.run(tf.global_variables_initializer())
     result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
                                                    b: hb})
+    # Benchmark
+    nanosec = triton.bench_registry[tr_d]
+    print('NANOSEC: ', nanosec)
+    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     # Test
     print(result[0][0])
     print(result[1][0])
     dif = np.abs(result[0][0] - result[1][0])
     print("dif: %f" % np.max(dif))
+


 def run_torch():
     torch.manual_seed(0)
-    M, N, K = 128, 128, 128
+    M, N, K = 2048, 2048, 2048
     a = torch.randn(M, K).cuda()
     b = torch.randn(K, N).cuda()
     a.requires_grad_(True)
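Note on the TFLOPS figure printed above: a dense M x N x K matmul performs 2*M*N*K floating-point operations, bench_registry reports the runtime in nanoseconds, and flops divided by nanoseconds gives GFLOP/s, so the extra 1e-3 converts to TFLOPS. A quick arithmetic check of the formula, assuming a purely illustrative 5 ms timing rather than a measured one:

M = N = K = 2048
flops = 2. * M * N * K         # about 1.72e10 floating-point operations
nanosec = 5.0e6                # assumed 5 ms per run, illustrative only
print('TFLOPS:', flops / nanosec * 1e-3)   # prints roughly 3.4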
@@ -37,9 +42,8 @@ def run_torch():
     torch_d = torch.matmul(torch.t(torch_c), b)
     torch_y = torch.mean(torch_d)
     triton_c = triton.ops.dot(a, b, False, True)
-    triton_d = triton.ops.dot(triton_c, b, True, False)
+    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
     triton_y = torch.mean(triton_d)
-
     # torch gradient
     torch_y.backward()
     torch_da = a.grad.clone()
@@ -51,6 +55,9 @@ def run_torch():
     triton_da = a.grad.clone()
     triton_db = b.grad.clone()

+    nanosec = triton.bench_registry[triton_d]
+    print(nanosec)
+    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     print('Diff DA:', (torch_da - triton_da).max())
     print('Diff DB:', (torch_db - triton_db).max())

@@ -12,7 +12,8 @@ from tensorflow.python.ops import gradient_checker

 one = 0
 out = 0
-bench = 0
+bench = 10
+
 class ProdKeyTest(tf.test.TestCase):

     def testEinsum(self):
@@ -36,9 +37,9 @@ class ProdKeyTest(tf.test.TestCase):
         # key_dim = 16

         for a_shape, b_shape, c_shape, einsum in [
-            [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-            [ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ],
-            [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
+            #[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
+            [ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
+            #[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
         ]:

             if one:
@@ -57,7 +58,7 @@ class ProdKeyTest(tf.test.TestCase):
             e = tf.placeholder(tf.float32, c_shape, name="e")
             feed_dict = { a:A, b:B, e:E }

-            cc = triton.ops.einsum(einsum, a, b)
+            cc = triton.ops.einsum(einsum, a, b, bench=bench)

             # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
             # print(error)
@@ -71,8 +72,12 @@ class ProdKeyTest(tf.test.TestCase):
             # c, = sess.run( [ c, ], feed_dict )
             c, da, db = sess.run( [ cc, da, db ], feed_dict )

-            if bench == 0:
+            if bench > 0:
+                nanosec = triton.bench_registry[cc]
+                print(A.shape, B.shape)
+                print(nanosec)

+            else:
                 C = np.einsum(einsum, A, B)
                 id = cc.op.get_attr('id')
                 ctx = triton.ops._einsum.contexts[id]
@@ -20,13 +20,13 @@ using namespace triton;

 namespace rt = triton::runtime;

-
-/* TF triton op properties */
-
 std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
+std::map<size_t, double> fp64scalar_map;
 std::map<size_t, int64_t> i64scalar_map;

+/* Grid map */
+
 void register_grid(size_t id,
                    const rt::function::grid_fn_ty& grid_fn) {
   id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
@@ -36,6 +36,8 @@ void delete_grid(size_t id) {
   id_grid_map.erase(id);
 }

+/* Function map */
+
 void register_fn(size_t id,
                  const std::string& src,
                  const rt::function::options_space_t& opt) {
@@ -56,8 +58,11 @@ size_t make_op_id() {
   return id_fn_map.size();
 }

+/* TF scalar wrapper */
 size_t make_scalar_id() {
-  return i64scalar_map.size();
+  size_t ret = i64scalar_map.size();
+  i64scalar_map[ret] = int64_t();
+  return ret;
 }

 bool has_scalar(size_t id) {
@@ -135,8 +140,9 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
   }
 }

-void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
-  os << "  (*id_fn_map.at(id_))({";
+void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
+  os << "  std::function<void()> run = [&](){\n  ";
+  os << "    (*id_fn_map.at(id_))({";
   for(unsigned i = 0; i < args.size() ; i++){
     ir::argument *arg = args[i];
     std::string name = arg->get_name();
@@ -146,7 +152,11 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
       os << ", ";
     os << name;
   }
-  os << "}, *id_grid_map.at(id_), stream); \n";
+  os << "}, *id_grid_map.at(id_), stream);\n";
+  os << "  };\n  ";
+  os << "  run();";
+  os << "  if(bench_ > 0)\n  ";
+  os << "    i64scalar_map[id_] = triton::tools::bench(run, stream);\n  ";
 }

 void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,7 +196,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
       throw std::runtime_error("unknown output");
     os << "  .Output(\"out" << i << ": T" << idx << "\")\n";
   }
-  os << "  .Attr(\"id: int\")" << std::endl;
+  os << "  .Attr(\"id: int\")\n";
+  os << "  .Attr(\"bench: int\")\n";
+  os << "  .Attr(\"bench_id: int\")\n";
   os << ";\n";
 }

@@ -247,6 +259,7 @@ std::tuple<std::string,
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/runtime/function.h"
+#include "triton/tools/bench.hpp"

 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -260,13 +273,15 @@ namespace drv = triton::driver;

 extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
+extern std::map<size_t, int64_t> i64scalar_map;

 class )" << opname << R"(: public OpKernel {
  public:
   explicit )" << opname << R"((OpKernelConstruction* context)
     : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
+    OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
+    OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
   }

   void Compute(OpKernelContext* context){
@@ -291,12 +306,14 @@ oss << R"(
 oss << R"(
     // launch function
 )";
-gen_make_launch_function(oss, fn->args());
+gen_make_launch_function(oss, outputs.size(), fn->args());
 oss << R"(
   }

 private:
   int id_;
+  int bench_;
+  int bench_id_;
 };

 // register kernel builder
@@ -379,6 +396,7 @@ void gen_torch_signature(std::ostringstream& oss,

   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
+  oss << "int64_t bench, ";
   for(size_t i = 0; i < args.size(); i++) {
     ir::argument* arg = args[i];
     if(i > 0)
@@ -420,7 +438,8 @@ void gen_torch_make_handles(std::ostream &os,
 }

 void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
-  os << "  (*id_fn_map.at(id))({";
+  os << "  std::function<void()> run = [&](){\n  ";
+  os << "    (*id_fn_map.at(id))({";
   for(unsigned i = 0; i < args.size() ; i++){
     ir::argument *arg = args[i];
     std::string name = "arg_" + arg->get_name();
@@ -431,7 +450,11 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
     os << name;
   }
   os << "}, *id_grid_map.at(id), &stream);\n";
-}
+  os << "  };\n  ";
+  os << "  run();";
+  os << "  if(bench > 0)\n  ";
+  os << "    i64scalar_map[id] = triton::tools::bench(run, stream);\n  ";
+}

 void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
   if(outputs.size() == 1){
@@ -465,6 +488,7 @@ std::tuple<std::string,
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/runtime/function.h"
+#include "triton/tools/bench.hpp"
 #include "torch/extension.h"
 #include "torch/script.h"
 #include "ATen/cuda/CUDAContext.h"
@@ -479,6 +503,7 @@ namespace drv = triton::driver;

 extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
+extern std::map<size_t, int64_t> i64scalar_map;

 )";

@@ -5,6 +5,7 @@ import shutil
 import hashlib
 import sysconfig
 import sys
+import weakref
 # import for just-in-time compilation
 import distutils
 import setuptools.command.build_ext
@@ -176,6 +177,38 @@ def _make_grid(args) :
     return grid


+class bench_dict:
+
+    # Lazy entry for e.g., tensorflow, when value of benchmark is
+    # not known at graph compile time
+    class lazy_entry:
+        def __init__(self, id):
+            self.id = id
+
+        def get(self):
+            return libtriton.retrieve_scalar(self.id)
+
+
+    def __init__(self):
+        self.data = dict()
+
+    def __delitem__(self, key):
+        del self.data[id(key)]
+
+    def __getitem__(self, key):
+        ret = self.data[id(key)]
+        if isinstance(ret, bench_dict.lazy_entry):
+            return ret.get()
+        return ret
+
+    def __len__(self):
+        return len(self.data)
+
+    def __setitem__(self, key, value):
+        self.data[id(key)] = value
+
+bench_registry = bench_dict()
+
 class kernel:

     def __init__(self, src, outputs):
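For orientation, a minimal sketch of how the registry defined above is meant to be consumed, mirroring the run_torch test earlier in this commit; it assumes a CUDA device and that triton.ops.dot accepts the trailing bench argument exactly as changed in this diff:

import torch
import triton

M, N, K = 2048, 2048, 2048
a = torch.randn(M, K).cuda()
b = torch.randn(K, N).cuda()

# A positive bench value makes the generated op time itself; the returned
# tensor is the key under which bench_dict stores the result (keyed by id(tensor)).
c = triton.ops.dot(a, b, False, True, 10)

# On the TensorFlow path the stored value is a lazy_entry resolved through
# libtriton.retrieve_scalar when read; on the PyTorch path the nanosecond
# count is stored directly when the op returns.
nanosec = triton.bench_registry[c]
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)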
@@ -200,7 +233,7 @@ class kernel:
             defines.append((k, values))
         opt = libtriton.options_space()
         opt.defines = defines
-        opt.num_warps = [4]
+        opt.num_warps = [2, 4, 8]
         # create unique id for this op
         op_id = libtriton.make_op_id()
         self.fw_id[key] = op_id
@@ -209,6 +242,10 @@ class kernel:
         if self.fw_op is None:
             self.fw_op = _make_framework_op(self.src, self.outputs, opt)

+        # benchmarking info
+        bench = 0
+        if 'bench' in kwargs:
+            bench = kwargs['bench']
         # retrieve framework op
         op_id = self.fw_id[key]
         # register grid
@@ -217,9 +254,16 @@ class kernel:
         op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
         # call framework function
         if fw.has_tensorflow():
-            return self.fw_op(*op_args, id=op_id)
+            bench_id = libtriton.make_scalar_id() if bench > 0 else 0
+            ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
+            if bench > 0:
+                bench_registry[ret] = bench_dict.lazy_entry(bench_id)
+
         elif fw.has_torch():
             args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
-            return self.fw_op(op_id, *args)
+            ret = self.fw_op(op_id, bench, *args)
+            if bench > 0:
+                bench_registry[ret] = libtriton.retrieve_scalar(op_id)
         else:
             assert False
+        return ret
@@ -11,38 +11,36 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // prologue
   int ridx = get_program_id(0);
   int ridy = get_program_id(1);
-  int rxa[TM] = ridx * TM + 0 ... TM;
-  int ryb[TN] = ridy * TN + 0 ... TN;
-  int rka[TK] = 0 ... TK;
-  int rkb[TK] = 0 ... TK;
+  int rm[TM] = ridx * TM + 0 ... TM;
+  int rn[TN] = ridy * TN + 0 ... TN;
+  int rk[TK] = 0 ... TK;
   float c[TM, TN] = 0;
   // pointers to operands
-  TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
-  TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
+  TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
+  TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
   // prefetches operands
-  TYPE a[SHAPE_A] = (*pa);
-  TYPE b[SHAPE_B] = (*pb);
+  TYPE a[SHAPE_A] = *pa;
+  TYPE b[SHAPE_B] = *pb;
   // reduction loop
   for(int k = K; k > 0; k-= TK){
     c += USE_A @ USE_B;
     pa = pa + TK * STRIDE_AK;
     pb = pb + TK * STRIDE_BK;
-    a = *pa;
-    b = *pb;
+    bool checka[SHAPE_A] = k > TK;
+    bool checkb[SHAPE_B] = k > TK;
+    a = checka ? *pa : 0;
+    b = checkb ? *pb : 0;
   }
   // epilogue
-  int rxc[TM] = ridx * TM + 0 ... TM;
-  int ryc[TN] = ridy * TN + 0 ... TN;
-  TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
-  bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
-  *?(checkc) pc = c;
+  TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :];
+  *pc = c;
 }
 """

     kernel = triton.kernel(src, ['C'])

     @staticmethod
-    def _call(a, b, transpose_a, transpose_b):
+    def _call(a, b, transpose_a, transpose_b, bench = 0):
         # extract shapes
         shape_a = triton.shape(a)
         shape_b = triton.shape(b)
@@ -78,16 +76,17 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
                   'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                   'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                   'SHAPE_B'     : 'TN, TK' if transpose_b else 'TK, TN'}
-        return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
+        return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
+                           grid, bench=bench,
                            AT = transpose_a, BT = transpose_b, TYPE = dtype,
-                           TM = [128], TN = [128], TK = [8], **macros)
+                           TM = [64, 128], TN = [64, 128], TK = [8], **macros)

     @staticmethod
-    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
+    def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0):
         ctx.save_for_backward(a, b)
         ctx.t_a = transpose_a
         ctx.t_b = transpose_b
-        return _dot._call(a, b, transpose_a, transpose_b)
+        return _dot._call(a, b, transpose_a, transpose_b, bench)

     @staticmethod
     def backward(ctx, dy):
@@ -108,5 +107,5 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
         else:
             assert False
         return da, db, None, None, None, None, None, None, None

 dot = _dot.apply
@@ -2,52 +2,58 @@


 import triton
+import math

 class _einsum(triton.function):

     src = """
 void einsum_(TYPE * A, TYPE * B, TYPE * C,
             int dim_M, int dim_N, int dim_K,
             int std_A0, int std_B0, int std_C0,
             int std_A1, int std_B1, int std_C1) {
   // program id
   int pgm = get_program_id(0);
   int pgn = get_program_id(1);
   int pgb = get_program_id(2);
   // range
   int rm[TM] = pgm * TM + 0 ... TM;
   int rn[TN] = pgn * TN + 0 ... TN;
   int rb[TB] = pgb * TB + 0 ... TB;
   int rk[TK] = 0 ... TK;
   // accumulator
   TYPE c[TM, TN, TB] = 0;
   // pointers to a
   TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
                         + rm[BROADCAST_AM] * STRIDE_AM
                         + rb[newaxis, newaxis, :] * std_A0;
   // pointers to b
   TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
                         + rn[BROADCAST_BN] * STRIDE_BN
                         + rb[newaxis, newaxis, :] * std_B0;
-  // accumulation
-  for(int k = dim_K; k > 0; k -= TK) {
-    TYPE a[SHAPE_A] = *pa;
-    TYPE b[SHAPE_B] = *pb;
-    c += USE_A @ USE_B;
-    pa += TK * STRIDE_AK;
-    pb += TK * STRIDE_BK;
-  }
+  // prefetch
+  TYPE a[SHAPE_A] = *pa;
+  TYPE b[SHAPE_B] = *pb;
+  // accumulation
+  for(int k = dim_K; k > 0; k -= TK) {
+    c += USE_A @ USE_B;
+    pa += TK * STRIDE_AK;
+    pb += TK * STRIDE_BK;
+    bool checka[SHAPE_A] = k > TK;
+    bool checkb[SHAPE_B] = k > TK;
+    a = checka ? *pa : 0;
+    b = checkb ? *pb : 0;
+  }
   // write-back
   TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
                            + rn[newaxis, :, newaxis] * 1
                            + rb[newaxis, newaxis, :] * std_C0;
   bool checkm[TM] = rm < dim_M;
   bool checkn[TN] = rn < dim_N;
   bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
                             checkn[newaxis, :, newaxis];
   *?(checkc)pc = c;
 }
 """

     kernel = triton.kernel(src, ['C'])

@@ -134,7 +140,8 @@ class _einsum(triton.function):

     @staticmethod
     def call(a, b, trans_a, trans_b, shape_c, bmnk,
-             std0, std1, einsum_a, einsum_b, einsum_c):
+             std0, std1, einsum_a, einsum_b, einsum_c,
+             bench):
         dtype = a.dtype
         c = triton.empty(shape_c, dtype)
         grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
|
|||||||
'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
|
'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
|
||||||
'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
|
'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
|
||||||
'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
|
'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
|
||||||
return _einsum.kernel(a, b, c,
|
TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
|
||||||
|
TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
|
||||||
|
TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
|
||||||
|
print(TM)
|
||||||
|
print(TN)
|
||||||
|
return _einsum.kernel(a, b, c,
|
||||||
bmnk[1], bmnk[2], bmnk[3],
|
bmnk[1], bmnk[2], bmnk[3],
|
||||||
std0[0], std0[1], std0[2],
|
std0[0], std0[1], std0[2],
|
||||||
std1[0], std1[1], std1[2],
|
std1[0], std1[1], std1[2],
|
||||||
grid, **macros,
|
grid, bench=bench,
|
||||||
TYPE='float', TM=32, TN=32, TK=8, TB=1)
|
**macros,
|
||||||
|
TYPE='float', TM=TM, TN=TN, TK=8, TB=TB)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, subscripts, a, b):
|
def forward(ctx, subscripts, a, b, **kwargs):
|
||||||
ctx.save_for_backward(a, b)
|
ctx.save_for_backward(a, b)
|
||||||
if type(subscripts) is str:
|
if type(subscripts) is str:
|
||||||
einsum_a, einsum_bc = subscripts.split(",")
|
einsum_a, einsum_bc = subscripts.split(",")
|
||||||
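The TM, TN and TB lists introduced in this hunk turn the problem shape into power-of-two autotuning candidates (bmnk is ordered batch, M, N, K here). A small worked example for the [4, 2048, 2048] x [2048, 2048] einsum case enabled in the test file above; pow2_range is only an illustrative helper, not part of the API:

import math

def pow2_range(dim, lo, hi_cap):
    # Same expression as the added lines:
    # [2**i for i in range(lo, max(lo + 1, min(hi_cap, int(math.log2(dim) + 1))))]
    return [2**i for i in range(lo, max(lo + 1, min(hi_cap, int(math.log2(dim) + 1))))]

print(pow2_range(2048, 5, 8))   # TM / TN candidates for M = N = 2048 -> [32, 64, 128]
print(pow2_range(4, 0, 3))      # TB candidates for batch = 4 -> [1, 2, 4]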
@@ -173,14 +186,16 @@ class _einsum(triton.function):

         shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
             einsum_a, einsum_b, einsum_c,
-            a.shape.as_list(), b.shape.as_list()
+            triton.shape(a), triton.shape(b)
         )
+        bench = kwargs['bench'] if 'bench' in kwargs else 0
         ctx.trans_a = ta
         ctx.trans_b = tb
         ctx.einsum_a = einsum_a
         ctx.einsum_b = einsum_b
         ctx.einsum_c = einsum_c
-        return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
+        ctx.bench = bench
+        return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)


     @staticmethod
@@ -191,22 +206,23 @@ class _einsum(triton.function):
         einsum_a = ctx.einsum_a
         einsum_b = ctx.einsum_b
         einsum_c = ctx.einsum_c
+        bench = ctx.bench

         if not trans_a and not trans_b: # NN
-            da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
-            db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
+            da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
+            db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)

         elif not trans_a and trans_b: # NT
-            da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
-            db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
+            da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
+            db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)

         elif trans_a and not trans_b: # TN
-            da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
-            db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
+            da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
+            db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)

         elif trans_a and trans_b: # TT (not used)
-            da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
-            db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
+            da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
+            db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)

         return da, db, None, None, None, None, None, None, None, None, None, None

@@ -22,7 +22,8 @@ class lazy_shape:

 def shape(A) :
   if fw.has_tensorflow():
-    return lazy_shape(fw.tensorflow.shape(A))
+    return A.shape.as_list()
+    #return lazy_shape(fw.tensorflow.shape(A))
   elif fw.has_torch():
     return A.shape
   else:
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }