[dnn][blocksparse] added dw code
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
   // benchmark cublas
   // NumericT alpha = 1;
   // NumericT beta = 0;
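The hunk above switches the benchmark from NO_TUNING to FULL_TUNING, so the autotuner searches the kernel's parameter space before the run is timed. As a point of reference, here is a minimal stand-in for the timing pattern; triton::tools::bench is assumed to average repeated invocations of the callback, and the chrono loop below is a hypothetical equivalent, not the library's implementation:

```cpp
// Hypothetical sketch of the bench(...) pattern: time `reps` invocations of a
// callback and return the average nanoseconds per call.
#include <chrono>
#include <cstdio>

template <typename F>
double bench_ns(F f, int reps = 100) {
    auto t0 = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < reps; i++) f();          // stand-in for dot.enqueue(...)
    auto t1 = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<double, std::nano>(t1 - t0).count() / reps;
}

int main() {
    double ns = bench_ns([] { /* kernel launch stand-in */ });
    std::printf("%g ns per call\n", ns);
}
```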
@@ -111,7 +111,7 @@ int main() {
   // shapes to benchmark
   std::vector<config_t> configs = {
     // {false, false, 8192, 512, 512},
-    {false, true, 128, 128, 128}
+    {false, true, 64, 64, 128}
     // {false, true, 128, 128, 128},
     // {false, false, 128, 128, 128},
     // {true, false, 128, 128, 128},
@@ -46,6 +46,19 @@ Status XpropShape(InferenceContext* ctx)
   return Status::OK();
 }
 
+Status UpdatShape(InferenceContext* ctx)
+{
+  //printf("UpdatShape: %d\n", ctx->Rank(ctx->input(0)));
+
+  int blocks, bsize;
+  TF_RETURN_IF_ERROR(ctx->GetAttr("blocks", &blocks));
+  TF_RETURN_IF_ERROR(ctx->GetAttr("bsize", &bsize));
+
+  // (blocks, block_size, block_size)
+  DimensionHandle bsize_dim = ctx->MakeDim(bsize);
+  ctx->set_output(0, ctx->MakeShape({ ctx->MakeDim(blocks), bsize_dim, bsize_dim }));
+  return Status::OK();
+}
 
 typedef struct bsmm_params
 {
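A note on the new shape function: whatever the dense activation shapes are, the weight gradient holds exactly one bsize x bsize tile per non-zero block, which is why UpdatShape only needs the blocks and bsize attributes. A self-contained sketch of the arithmetic (example values, not TensorFlow code):

```cpp
// The dw output declared by UpdatShape is (blocks, bsize, bsize), so its
// element count is blocks * bsize * bsize, independent of the input shapes.
#include <cstdio>

int main() {
    int blocks = 4, bsize = 32;                 // example attribute values
    long elems = (long)blocks * bsize * bsize;  // shape (blocks, bsize, bsize)
    std::printf("dw shape: (%d, %d, %d) = %ld elems\n",
                blocks, bsize, bsize, elems);
}
```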
@@ -72,34 +85,46 @@ typedef struct bsmm_params
 
 template<triton::dnn::blocksparse::op_t OP, typename T>
 class BlocksparseMatmulOp : public OpKernel {
-public:
-  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks",   &params_.blocks  ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize",    &params_.bsize   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("C",        &params_.C       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("K",        &params_.K       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared",   &params_.shared  ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",    &params_.alpha   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta",     &params_.beta    ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",     &axis_           ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_));
-    OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
-    OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
-    params_.pcount = 1;
-    params_.blk_A = 0;
-    is_gpu_ = ctx->device_type() == DEVICE_GPU;
-    if (bench_) {
-      repeat_ = bench_;
-      flops_ = (float)(params_.blocks * params_.bsize*params_.bsize);
-      const char* op = "FPROP";
-      sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
-    }
+private:
+  void ComputeDw(OpKernelContext* context){
+    // get device/stream
+    GPUDevice device = context->eigen_device<GPUDevice>();
+    triton::driver::cu_stream sstream(device.stream(), false);
+    triton::driver::context* ctx = sstream.context();
+    triton::driver::stream* stream = &sstream;
+    // extract input
+    OpInputList x, dy, gate;
+    context->input_list(   "x", &x);
+    context->input_list(  "dy", &dy);
+    context->input_list("gate", &gate);
+    // sanity checks
+    params_.pcount = x.size();
+    if (params_.pcount > 1)
+      errors::Internal("No more than 1 input allowed.");
+    if (params_.beta != 0.0f || params_.alpha != 1.0f)
+      errors::Internal("Not supported yet");
+    // N
+    int N = 1;
+    int rank = x[0].dims();
+    for (int i = 0; i < rank; i++)
+      if (i != axis_)
+        N *= x[0].dim_size(i);
+    // allocate output
+    Tensor* C;
+    TensorShape shapeC({ params_.blocks, params_.bsize, params_.bsize });
+    OP_REQUIRES_OK(context, context->allocate_output(0, shapeC, &C));
+    // wrap tensorflow handles
+    triton::driver::cu_buffer da(ctx, x[0].tensor_data().size(), (CUdeviceptr)x[0].tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, dy[0].tensor_data().size(), (CUdeviceptr)dy[0].tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, C->tensor_data().size(), (CUdeviceptr)C->tensor_data().data(), false);
+    triton::driver::cu_buffer dlut(ctx, context->input(params_.pcount*2).tensor_data().size(), (CUdeviceptr)context->input(params_.pcount*2).tensor_data().data(), false);
+    // create profile
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
+    // enqueue
+    dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
   }
 
-  void Compute(OpKernelContext* context){
+  void ComputeYDx(OpKernelContext* context){
     // get device/stream
     GPUDevice device = context->eigen_device<GPUDevice>();
     triton::driver::cu_stream sstream(device.stream(), false);
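ComputeDw collapses every input dimension except the feature axis into a single N before building the blocksparse::dot profile. A self-contained sketch of that reduction, using a hypothetical input shape:

```cpp
// Sketch of how ComputeDw derives N: the product of all input dimensions
// except the feature axis (axis_). The shape below is an example.
#include <cstdio>
#include <vector>

int main() {
    std::vector<long> dims = {64, 128, 512};  // hypothetical x[0] shape
    int axis = 2;                             // feature axis (C lives here)
    long N = 1;
    for (size_t i = 0; i < dims.size(); i++)
        if ((int)i != axis)
            N *= dims[i];
    std::printf("N = %ld\n", N);              // 64 * 128 = 8192
}
```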
@@ -129,8 +154,8 @@ public:
     triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
     // create profile
     triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
-    // blocksparse matmul
-    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::PARTIAL_TUNING);
+    // enqueue
+    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
     triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
     Tensor *tmp = nullptr;
     TensorShape tmp_shapes;
@@ -138,6 +163,41 @@ public:
       OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
     }
 
+public:
+
+  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks",   &params_.blocks  ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize",    &params_.bsize   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("C",        &params_.C       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("K",        &params_.K       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared",   &params_.shared  ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",    &params_.alpha   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta",     &params_.beta    ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",     &axis_           ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_));
+    OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
+    OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
+    params_.pcount = 1;
+    params_.blk_A = 0;
+    is_gpu_ = ctx->device_type() == DEVICE_GPU;
+    if (bench_) {
+      repeat_ = bench_;
+      flops_ = (float)(params_.blocks * params_.bsize*params_.bsize);
+      const char* op = "FPROP";
+      sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
+    }
+  }
+
+  void Compute(OpKernelContext* context) override{
+    if(OP == triton::dnn::blocksparse::WGRAD)
+      ComputeDw(context);
+    else
+      ComputeYDx(context);
+  }
+
 private:
   bsmm_params params_;
   int axis_, bench_, repeat_, SMs_, major_, grid_n_;
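Compute now dispatches on the op mode, and since OP is a template parameter the branch is resolved at compile time: each registered kernel instantiation only ever contains one path. A minimal self-contained sketch of the pattern:

```cpp
// Sketch of compile-time dispatch on a template op tag, as used by Compute().
#include <cstdio>

enum op_t { FPROP, BPROP, WGRAD };

template <op_t OP>
void compute() {
    if (OP == WGRAD)                 // constant-folded per instantiation
        std::printf("ComputeDw path\n");
    else
        std::printf("ComputeYDx path\n");
}

int main() {
    compute<FPROP>();  // prints: ComputeYDx path
    compute<WGRAD>();  // prints: ComputeDw path
}
```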
@@ -212,3 +272,33 @@ Multiply the matrix "a" by the blocksparse matrix "b".
 REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<float>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, float>);
 REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, Eigen::half>);
+
+
+REGISTER_OP("TritonBlocksparseMatmulDW")
+    .Input("x: params * T")
+    .Input("dy: params * T")
+    .Input("lut: int64")
+    .Input("gate: ngate * float")
+    .Output("dw: T")
+    .Attr("T: {half, float, bfloat16}")
+    .Attr("params: int")
+    .Attr("blocks: int >=0")
+    .Attr("bsize: int")
+    .Attr("segments: int = 0")
+    .Attr("locks: int = 0")
+    .Attr("axis: int = 1")
+    .Attr("C: int >=0")
+    .Attr("K: int >=0")
+    .Attr("shared: int = 0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.0")
+    .Attr("gated_dw: bool = false")
+    .Attr("gate_grad: bool = false")
+    .Attr("bench: int = 0")
+    .Attr("ngate: int >= 0")
+    .SetShapeFn(UpdatShape)
+    .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint<float>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::WGRAD, float>);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::WGRAD, Eigen::half>);
@@ -29,6 +29,9 @@ private:
   void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
   // deinit
   void deinit_impl();
+  // source
+  std::string triton_c_src_ydx() const;
+  std::string triton_c_src_dw() const;
 public:
   // constructor
   dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
@@ -75,7 +75,6 @@ public:
     optimize_dot.run(module);
     optimize_trans.run(module);
     dce.run(module);
-    // ir::print(module, std::cout);
   }
 
   void target_dependent(ir::module &module) {
@@ -89,6 +88,7 @@ public:
     }
     vectorize.run(module);
     dce.run(module);
+    // ir::print(module, std::cout);
   }
 
   codegen::selection selection;
@@ -250,15 +250,15 @@ void tune::run(ir::module &mod) {
     node_t node = *nodes_.begin();
     if(fragments_[node] == STRIDED_SCAN) {
       ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
-      ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
+      ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
       connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
       nts->set_value(1);
     }
     else {
-      ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
+      ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
       if(node.second == 2)
         fpw->set_value(1);
-      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
       connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
     }
   }
|
|||||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||||
if(addr_space < 4){
|
if(addr_space < 4){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -287,8 +287,8 @@ void tune::run(ir::module &mod) {
     // *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
     // *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
     if(fragments_.at({i, 0}) == STRIDED_SCAN){
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
       // for(size_t k = 2; k < shapes.size(); k++)
@@ -423,7 +423,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
       for(size_t k = 0; k < shapes.size(); k++){
         prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
       }
-      if(prod != 4)
+      if(prod > 4)
        errors[i].push_back("HMMA must have only 4 fragments per warp");
     }
     int num_threads = get_req_num_threads(i);
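The fragments-per-warp constraint is relaxed from exactly 4 to at most 4, which admits the newly allowed fpw = (1, 1) configurations from the tune changes above. A tiny self-contained check mirroring the logic:

```cpp
// Mirrors the relaxed check: reject only when the product of fpw.d* across
// tile dimensions exceeds 4 (values illustrative).
#include <cstdio>

int main() {
    int fpw[2] = {1, 1};   // previously rejected (prod != 4), now legal
    int prod = 1;
    for (int k = 0; k < 2; k++)
        prod *= fpw[k];
    if (prod > 4)
        std::printf("rejected: HMMA must have only 4 fragments per warp\n");
    else
        std::printf("accepted: prod = %d\n", prod);
}
```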
@@ -554,7 +554,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   num_packs_0_ = num_rep_0 / pack_size_0_;
   num_packs_1_ = num_rep_1 / pack_size_1_;
 
-
   /* intra warp offset */
   // offset of quad in pair
   Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
|||||||
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||||
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||||
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
||||||
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
|
pair_b_id = builder.CreateURem(builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)), builder.getInt32(fpw_1));
|
||||||
// Quad pair offset
|
// Quad pair offset
|
||||||
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
||||||
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
||||||
|
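The URem added to pair_b_id keeps the pair index inside [0, fpw_1); with small fpw values the old expression could walk past the quad-pair grid. A self-contained arithmetic sketch (fpw values are illustrative):

```cpp
// For each thread id in a warp's first 16 lanes: pair_b = (tid % 16) / 4.
// With fpw = (1, 2), the old pair_b / fpw_0 reaches 2 and 3, outside
// [0, fpw_1); the added URem wraps it back into range.
#include <cstdio>

int main() {
    int fpw_0 = 1, fpw_1 = 2;
    for (int tid = 0; tid < 16; tid++) {
        int pair_b = (tid % 16) / 4;
        int old_id = pair_b / fpw_0;            // may exceed fpw_1 - 1
        int new_id = (pair_b / fpw_0) % fpw_1;  // always in [0, fpw_1)
        std::printf("tid=%2d old=%d new=%d\n", tid, old_id, new_id);
    }
}
```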
@@ -52,40 +52,56 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   driver::buffer *b = args[1];
   driver::buffer *c = args[2];
   driver::buffer *lut = args[3];
-  int32_t lda = N_;
-  int32_t ldc = N_;
   kernel->setArg(0, a);
   kernel->setArg(1, b);
   kernel->setArg(2, c);
-  kernel->setArg(3, lda);
-  kernel->setArg(4, ldc);
-  kernel->setArg(5, N_);
-  kernel->setArg(6, lut);
-  kernel->setArg(7, locks_.get());
-  kernel->setArg(8, nlocks_);
-  int32_t TM = info.globals["TM"];
-  size_t grid_0 = (N_ + TM - 1) / TM;
-  size_t grid_1 = S_;
-  if(nlocks_)
-    ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-  stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+  if(op_ == FPROP || op_ == BPROP){
+    kernel->setArg(3, N_);
+    kernel->setArg(4, BS_);
+    kernel->setArg(5, N_);
+  }
+  else{
+    kernel->setArg(3, N_);
+    kernel->setArg(4, N_);
+    kernel->setArg(5, BS_);
+  }
+  kernel->setArg(6, N_);
+  kernel->setArg(7, lut);
+  kernel->setArg(8, locks_.get());
+  kernel->setArg(9, nlocks_);
+  if(op_ == FPROP || op_ == BPROP){
+    int32_t TM = info.globals["TM"];
+    size_t grid_0 = (N_ + TM - 1) / TM;
+    size_t grid_1 = S_;
+    if(nlocks_)
+      ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
+    stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+  }
+  else{
+    size_t grid_0 = nblocks_;
+    stream->enqueue(kernel, {grid_0, 1, 1}, {info.num_threads, 1, 1});
+  }
 }
 
 driver::buffer* dot::get_locks() const {
   return locks_.get();
 }
 
-void dot::triton_c_src(std::ostream &os) const {
-  std::string usea = (op_ == WGRAD) ? "trans(a)" : "a";
-  std::string useb = (op_ == FPROP) ? "trans(b)" : "b";
+std::string dot::triton_c_src_ydx() const {
+  bool AT = (op_ == WGRAD);
+  bool BT = (op_ == FPROP);
+  std::string usea = AT ? "trans(a)" : "a";
+  std::string useb = BT ? "trans(b)" : "b";
   std::string sizea = "TM, TK";
-  std::string sizeb = (op_ == FPROP) ? "TN, TK" : "TK, TN";
+  std::string sizeb = BT ? "TN, TK" : "TK, TN";
   std::string bca0 = ":, newaxis";
   std::string bca1 = "newaxis, :";
-  std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
-  std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
-  std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
-  std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
+  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
+  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
+  std::string lda0 = AT ? "*lda" : "";
+  std::string lda1 = AT ? "" : "*lda";
+  std::string ldb0 = BT ? "" : "*ldb";
+  std::string ldb1 = BT ? "*ldb" : "" ;
   std::string result =
   R"(
   const tunable int TM = {16, 32, 64, 128};
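enqueue_impl now chooses between two launch shapes: FPROP/BPROP tile the dense N dimension and iterate over segments, while WGRAD launches exactly one program per non-zero weight block. A sketch of the grid arithmetic with illustrative sizes:

```cpp
// Illustrative grid computation matching the two branches of enqueue_impl.
#include <cstdio>

int main() {
    int N = 8192, TM = 128, S = 4, nblocks = 70;  // example values
    size_t ydx_grid0 = (N + TM - 1) / TM;  // ceil-div over the dense dim
    size_t ydx_grid1 = S;                  // one program row per segment
    size_t dw_grid0  = nblocks;            // one program per sparse block
    std::printf("ydx grid: {%zu, %zu, 1}\n", ydx_grid0, ydx_grid1);
    std::printf("dw  grid: {%zu, 1, 1}\n", dw_grid0);
}
```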
@@ -95,26 +111,25 @@ void dot::triton_c_src(std::ostream &os) const {
   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
              restrict read_only align(16) )" + ab_ty_ + R"( *B,
              )" + c_ty_ + R"(* C,
-             int lda, int ldc, int N,
-             int* lut, int* locks, int nlocks) {
+             int lda, int ldb, int ldc,
+             int N, int* lut,
+             int* locks, int nlocks) {
     int ridx = get_range_id(0);
-    int ridy = get_range_id(1);
     float acc[TM, TN] = 0;
-    int rxa[TM] = ridx * TM + (0 ... TM);
-    int ryb[TN] = 0 ... TN;
     int rka[TK] = 0 ... TK;
     int rkb[TK] = 0 ... TK;
-    bool checka[TM, TK] = (rxa < N)[:, newaxis];
-    int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
-    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
-    int *header = lut + ridy * 4;
+    int *header = lut + get_range_id(1) * 4;
     int offset = *(header + 0);
     int K = *(header + 1);
     int column = *(header + 2);
     int lockid = *(header + 3);
+    int rxa[TM] = ridx * TM + (0 ... TM);
+    int ryb[TN] = 0 ... TN;
     int *plut = lut + offset * 2;
-    for(int k = K; k > 0; k = k - 1)
-    {
+    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
+    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+    bool checka[TM, TK] = (rxa < N)[:, newaxis];
+    for(int k = K; k > 0; k = k - 1) {
       int ak = *(plut + 0);
       int bk = *(plut + 1);
       )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
@@ -137,17 +152,83 @@ void dot::triton_c_src(std::ostream &os) const {
     int *pcount = plock + get_num_program(0)*nlocks;
     while(__atomic_cas(plock, 0, 1));
     int count = *pcount;
-    if(count == 0){
+    if(count == 0)
       @checkc *pc = c;
-    }
-    else{
+    else
       @checkc *pc = c + *pc;
-    }
     __atomic_exch(pcount, 1);
     __atomic_exch(plock, 0);
   }
 })";
-  os << result;
+  return result;
+}
+
+std::string dot::triton_c_src_dw() const {
+  bool AT = (op_ == WGRAD);
+  bool BT = (op_ == FPROP);
+  std::string usea = AT ? "trans(a)" : "a";
+  std::string useb = BT ? "trans(b)" : "b";
+  std::string sizea = AT ? "TK, TM" : "TM, TK";
+  std::string sizeb = BT ? "TN, TK" : "TK, TN";
+  std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
+  std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
+  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
+  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
+  std::string lda0 = AT ? "*lda" : "";
+  std::string lda1 = AT ? "" : "*lda";
+  std::string ldb0 = BT ? "" : "*ldb";
+  std::string ldb1 = BT ? "*ldb" : "" ;
+  std::string result =
+  R"(
+  const tunable int TM = {)" + std::to_string(BS_) + R"(};
+  const tunable int TN = {)" + std::to_string(BS_) + R"(};
+  const tunable int TK = {32};
+
+  void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
+             restrict read_only align(16) )" + ab_ty_ + R"( *B,
+             )" + c_ty_ + R"(* C,
+             int lda, int ldb, int ldc,
+             int N, int* lut,
+             int* locks, int nlocks) {
+    int ridx = get_range_id(0);
+    float acc[TM, TN] = 0;
+    int rka[TK] = 0 ... TK;
+    int rkb[TK] = 0 ... TK;
+    int *header = lut + ridx * 2;
+    int offx = *(header + 0);
+    int offy = *(header + 1);
+    int rxa[TM] = offx*TM + (0 ... TM);
+    int ryb[TN] = offy*TN + (0 ... TN);
+    bool checka[TK, TM] = (rka < N)[:, newaxis];
+    bool checkb[TK, TN] = (rkb < N)[:, newaxis];
+    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
+    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+    )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
+    )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
+    )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
+    )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
+    for(int k = N; k > 0; k = k - TK) {
+      acc = dot()" + usea + ", " + useb + R"(, acc);
+      pa = pa + TK)" + lda1 + R"(;
+      pb = pb + TK)" + ldb1 + R"(;
+      a = checka ? *pa : 0;
+      b = checkb ? *pb : 0;
+    }
+    int rxc[TM] = (0 ... TM);
+    int ryc[TN] = (0 ... TN);
+    )" + c_ty_ + R"( c[TM, TN] = acc;
+    )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
+    *pc = c;
+})";
+
+  return result;
+}
+
+void dot::triton_c_src(std::ostream &os) const {
+  if(op_ == FPROP || op_ == BPROP)
+    os << triton_c_src_ydx();
+  else
+    os << triton_c_src_dw();
 }
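The ydx kernel's epilogue serializes partial results with a spin lock: the first program to commit a tile stores it, later programs accumulate into it, and pcount records that a value is already present. A self-contained host-side sketch of the same protocol, with std::atomic standing in for the kernel's __atomic_cas/__atomic_exch:

```cpp
// Host-side analogue of the lock-guarded accumulation in the generated
// kernel: acquire the lock by CAS, store on first commit, add afterwards.
#include <atomic>
#include <cstdio>

std::atomic<int> lock{0}, count{0};
float C = 0.0f;  // stands in for the guarded output tile

void commit(float c) {
    int expected = 0;
    while (!lock.compare_exchange_weak(expected, 1))  // spin until acquired
        expected = 0;
    if (count.load() == 0) C = c;   // first writer stores
    else                   C += c;  // later writers accumulate
    count.store(1);                 // __atomic_exch(pcount, 1)
    lock.store(0);                  // __atomic_exch(plock, 0)
}

int main() {
    commit(1.5f);
    commit(2.5f);
    std::printf("C = %g\n", C);     // prints: C = 4
}
```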
@@ -106,8 +106,8 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
   R"(
-  const tunable int TM = {32};
-  const tunable int TN = {32};
+  const tunable int TM = {8};
+  const tunable int TN = {8};
   const tunable int TK = {32};
   const tunable int GZ = {1};
 
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  ThreadPool pool(nthreads);
+  // ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    pool.enqueue(f,values);
-    // f(values);
+    // pool.enqueue(f,values);
+    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
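With the thread pool commented out, parallel_loop_nest degenerates to a serial walk of the iteration space. The inner while advances values like an odometer; the reset-and-carry lines fall outside this hunk, so the sketch below reconstructs them under that assumption:

```cpp
// Serial loop-nest walk: run f inline, then advance `values` odometer-style,
// resetting exhausted digits and carrying into the next-outer dimension.
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> ranges = {2, 3};         // example iteration space
    std::vector<size_t> values(ranges.size(), 0);
    size_t i = ranges.size() - 1;                // start at innermost loop
    while (true) {
        std::printf("(%zu, %zu)\n", values[0], values[1]);  // f(values)
        while (values[i]++ == ranges[i] - 1) {
            if (i == 0) return 0;                // outermost digit exhausted
            values[i--] = 0;                     // reset digit, carry outward
        }
        i = ranges.size() - 1;                   // resume at innermost loop
    }
}
```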