[dnn/blocksparse/dot]: BlocksparseDx also working
@@ -45,39 +45,6 @@ Status XpropShape(InferenceContext* ctx)
 }
 
 
-REGISTER_OP("TritonBlocksparseMatmul")
-    .Input("x: T")
-    .Input("w: T")
-    .Input("lut: int64")
-    .Input("lut_dx: int64")
-    .Input("lut_dw: int64")
-    .Input("gate: ngate * float")
-    .Output("y: T")
-    .Output("temp: int32")
-    .Attr("T: {half, float, bfloat16}")
-    .Attr("blocks: int >=0")
-    .Attr("bsize: int")
-    .Attr("segments: int = 0")
-    .Attr("segments_dx: int = 0")
-    .Attr("locks: int = 0")
-    .Attr("locks_dx: int = 0")
-    .Attr("axis: int = 1")
-    .Attr("C: int >=0")
-    .Attr("K: int >=0")
-    .Attr("shared: int = 0")
-    .Attr("shared_dx: int = 0")
-    .Attr("alpha: float = 1.0")
-    .Attr("beta: float = 0.0")
-    .Attr("gated_dw: bool = false")
-    .Attr("gate_grad: bool = false")
-    .Attr("bench: int = 0")
-    .Attr("ngate: int >= 0")
-    .SetShapeFn(XpropShape)
-    .Doc(R"doc(
-Multiply the matrix "a" by the blocksparse matrix "b".
-)doc");
-
 typedef struct bsmm_params
 {
   const int* Lut;
@@ -101,7 +68,7 @@ typedef struct bsmm_params
   CUstream stream;
 } bsmm_params;
 
-template<typename T>
+template<triton::dnn::blocksparse::op_t OP, typename T>
 class BlocksparseMatmulOp : public OpKernel {
 public:
   explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -167,7 +134,7 @@ public:
     triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
     triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
     // create profile
-    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks);
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
     // blocksparse matmul
     dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
   }
@@ -180,5 +147,69 @@ private:
   char bench_string_[256];
 };
 
-REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<float>);
-REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<Eigen::half>);
+REGISTER_OP("TritonBlocksparseMatmul")
+    .Input("x: T")
+    .Input("w: T")
+    .Input("lut: int64")
+    .Input("lut_dx: int64")
+    .Input("lut_dw: int64")
+    .Input("gate: ngate * float")
+    .Output("y: T")
+    .Output("temp: int32")
+    .Attr("T: {half, float, bfloat16}")
+    .Attr("blocks: int >=0")
+    .Attr("bsize: int")
+    .Attr("segments: int = 0")
+    .Attr("segments_dx: int = 0")
+    .Attr("locks: int = 0")
+    .Attr("locks_dx: int = 0")
+    .Attr("axis: int = 1")
+    .Attr("C: int >=0")
+    .Attr("K: int >=0")
+    .Attr("shared: int = 0")
+    .Attr("shared_dx: int = 0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.0")
+    .Attr("gated_dw: bool = false")
+    .Attr("gate_grad: bool = false")
+    .Attr("bench: int = 0")
+    .Attr("ngate: int >= 0")
+    .SetShapeFn(XpropShape)
+    .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, float>);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, Eigen::half>);
+
+
+REGISTER_OP("TritonBlocksparseMatmulDX")
+    .Input("dy: T")
+    .Input("w: T")
+    .Input("lut: int64")
+    .Input("gate: ngate * float")
+    .Output("dx: T")
+    .Output("temp: int32")
+    .Attr("T: {half, float, bfloat16}")
+    .Attr("blocks: int >=0")
+    .Attr("bsize: int")
+    .Attr("segments: int = 0")
+    .Attr("locks: int = 0")
+    .Attr("axis: int = 1")
+    .Attr("C: int >=0")
+    .Attr("K: int >=0")
+    .Attr("shared: int = 0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.0")
+    .Attr("gated_dw: bool = false")
+    .Attr("gate_grad: bool = false")
+    .Attr("bench: int = 0")
+    .Attr("ngate: int >= 0")
+    .SetShapeFn(XpropShape)
+    .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, float>);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, Eigen::half>);
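For context, the registration pattern above instantiates one C++ kernel class per (direction, dtype) pair: the op_t template argument is fixed at registration time, so FPROP and BPROP get separately compiled Compute bodies behind distinct op names. A minimal standalone sketch of the idea (names here are illustrative stand-ins, not the actual TensorFlow API):

    #include <iostream>

    // Stand-in for triton::dnn::blocksparse::op_t.
    enum op_t { FPROP, BPROP, WGRAD };

    // Stand-in for BlocksparseMatmulOp: OP is a compile-time constant,
    // so each instantiation below gets its own specialized body.
    template<op_t OP, typename T>
    struct MatmulOp {
      void Compute() {
        if (OP == FPROP)      std::cout << "y = x . w\n";     // forward
        else if (OP == BPROP) std::cout << "dx = dy . w^T\n"; // backward-data
        else                  std::cout << "dw = x^T . dy\n"; // backward-weights
      }
    };

    int main() {
      MatmulOp<FPROP, float>().Compute(); // "TritonBlocksparseMatmul"
      MatmulOp<BPROP, float>().Compute(); // "TritonBlocksparseMatmulDX"
    }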
@@ -7,6 +7,11 @@ namespace triton{
 namespace dnn{
 namespace blocksparse{
 
+enum op_t{
+  FPROP,
+  BPROP,
+  WGRAD
+};
 
 class dot: public base {
 private:
@@ -26,7 +31,7 @@ private:
   void deinit_impl();
 public:
   // constructor
-  dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks);
+  dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, op_t op = FPROP);
   // triton-c source
   void triton_c_src(std::ostream &os) const;
   // clone
@@ -42,6 +47,7 @@ private:
   int32_t BS_;
   int32_t nlocks_;
   driver::buffer *locks_;
+  op_t op_;
 };
 
 }
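A note on the defaulted constructor argument: op_t op = FPROP keeps every pre-existing seven-argument call site compiling unchanged while the new DX kernel opts into BPROP. A hedged standalone illustration of that pattern (a toy stand-in, not the real class):

    #include <iostream>

    enum op_t { FPROP, BPROP, WGRAD };

    // Simplified stand-in for the dot profile class.
    struct dot {
      int BS_;
      op_t op_;
      dot(int BS, op_t op = FPROP) : BS_(BS), op_(op) {}
    };

    int main() {
      dot fwd(32);        // legacy call site: op_ defaults to FPROP
      dot bwd(32, BPROP); // new backward-data call site
      std::cout << fwd.op_ << " " << bwd.op_ << "\n"; // prints "0 1"
    }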
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
+        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
         *params_.at(i).at("nts.d0") = *tmp;
       }
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
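The (4, 8) to (2, 2) change above narrows the bounds that ir::metaparameter::create hands the autotuner for the nts parameters; with min == max the value is effectively pinned and drops out of the search. Assuming the bounds enumerate powers of two between min and max (an assumption, not something this diff confirms), a toy illustration of why equal bounds pin the parameter:

    #include <iostream>
    #include <vector>

    // Hypothetical candidate enumeration for a (min, max) metaparameter range.
    std::vector<int> candidates(int lo, int hi) {
      std::vector<int> v;
      for (int x = lo; x <= hi; x *= 2) // powers of two between the bounds
        v.push_back(x);
      return v;
    }

    int main() {
      std::cout << candidates(4, 8).size() << "\n"; // 2 candidates: 4 and 8
      std::cout << candidates(2, 2).size() << "\n"; // 1 candidate: pinned at 2
    }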
@@ -13,8 +13,8 @@ bool dot::operator <(const base& other) const {
   auto *y = dynamic_cast<const dot*>(&other);
   if(!y)
     return true;
-  return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_)
-       < std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_);
+  return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_, op_)
+       < std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_, y->op_);
 }
 
 std::vector<params_t> dot::search_space() const {
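Appending op_ to the std::tie comparison matters because this operator presumably orders profiles inside whatever container caches compiled kernels per profile; without op_ in the key, an FPROP and a BPROP dot with identical shapes would compare equal and share one kernel. A standalone sketch of the failure mode being avoided (reduced key, hypothetical cache):

    #include <iostream>
    #include <map>
    #include <string>
    #include <tuple>

    enum op_t { FPROP, BPROP, WGRAD };

    // Reduced profile key: shapes plus the pass direction.
    struct key {
      int N, BS;
      op_t op;
      bool operator<(const key& o) const {
        return std::tie(N, BS, op) < std::tie(o.N, o.BS, o.op);
      }
    };

    int main() {
      std::map<key, std::string> cache;
      cache[{1024, 32, FPROP}] = "fprop kernel";
      // Distinct entry only because op participates in the ordering:
      cache[{1024, 32, BPROP}] = "bprop kernel";
      std::cout << cache.size() << "\n"; // 2
    }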
@@ -30,11 +30,11 @@ base * dot::clone() const {
 }
 
 dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
-         const std::string& ty, int32_t BS, int32_t nlocks):
+         const std::string& ty, int32_t BS, int32_t nlocks, op_t op):
   base("bsdot"),
   N_(N), K_(K), S_(S), C_(C),
   ab_ty_(ty), c_ty_(ty),
-  BS_(BS), nlocks_(nlocks) {
+  BS_(BS), nlocks_(nlocks), op_(op){
 }
 
 void dot::init_impl(driver::stream *stream, driver::cu_module *module) {
@@ -72,27 +72,33 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   int32_t TM = info.globals["TM"];
   size_t grid_0 = (N_ + TM - 1) / TM;
   size_t grid_1 = S_;
-  std::cout << N_ << " " << grid_0 << std::endl;
-  if(nlocks_){
-    // locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
+  if(nlocks_)
     ((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-  }
   stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
 }
 
 void dot::triton_c_src(std::ostream &os) const {
+  std::string usea  = (op_ == WGRAD) ? "trans(a)" : "a";
+  std::string useb  = (op_ == FPROP) ? "trans(b)" : "b";
+  std::string sizea = "TM, TK";
+  std::string sizeb = (op_ == FPROP) ? "TN, TK" : "TK, TN";
+  std::string bca0  = ":, newaxis";
+  std::string bca1  = "newaxis, :";
+  std::string bcb0  = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
+  std::string bcb1  = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
+  std::string ldb0  = (op_ == FPROP) ? "1" : "TK";
+  std::string ldb1  = (op_ == FPROP) ? "TK" : "1";
   std::string result =
 R"(
 const tunable int32 TM = {64};
 const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
 const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
 
 void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
            restrict read_only align(16) )" + ab_ty_ + R"( *B,
            )" + c_ty_ + R"(* C,
            int32 lda, int32 ldc, int32 N,
            int32* lut, int32* locks, int32 nlocks){
   int32 ridx = get_range_id(0);
   int32 ridy = get_range_id(1);
   fp32 acc[TM, TN] = 0;
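The launch geometry above is worth spelling out: grid_0 covers the N dimension in TM-row programs, grid_1 runs one program per LUT segment, and the lock scratch that set_zero clears holds two int32 values per (program, lock) pair (the plock and pcount words used at the end of the kernel), hence grid_0 * nlocks_ * 2 * 4 bytes. A quick standalone check of the arithmetic, with made-up sizes:

    #include <cstddef>
    #include <cstdio>

    int main() {
      int N = 1000, TM = 64, S = 12, nlocks = 4;
      int grid_0 = (N + TM - 1) / TM; // ceil-divide: 16 programs along rows
      int grid_1 = S;                 // one program per LUT segment
      size_t lock_bytes = (size_t)grid_0 * nlocks * 2 * 4; // two int32 per pair
      std::printf("grid = (%d, %d), lock bytes = %zu\n", grid_0, grid_1, lock_bytes);
    }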
@@ -100,22 +106,22 @@ void dot::triton_c_src(std::ostream &os) const {
   int32 ryb[TN] = 0 ... TN;
   int32 rka[TK] = 0 ... TK;
   int32 rkb[TK] = 0 ... TK;
-  int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
-  int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
+  int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
+  int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
   int32 *header = lut + ridy * 4;
   int32 offset = *(header + 0);
   int32 K = *(header + 1);
   int32 column = *(header + 2);
   int32 lockid = *(header + 3);
   int32 *plut = lut + offset * 2;
   for(int32 k = K; k > 0; k = k - 1){
     int32 ak = *(plut + 0);
     int32 bk = *(plut + 1);
-    )" + ab_ty_ + R"(* pa[TM, TK] = A + offa + ak * TK * lda;
-    )" + ab_ty_ + R"(* pb[TK, TN] = B + offb + bk * TK * TN;
-    )" + ab_ty_ + R"( a[TM, TK] = *pa;
-    )" + ab_ty_ + R"( b[TK, TN] = *pb;
-    acc = dot(a, b, acc);
+    )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
+    )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
+    )" + ab_ty_ + " a[" + sizea + R"(] = *pa;
+    )" + ab_ty_ + " b[" + sizeb + R"(] = *pb;
+    acc = dot()" + usea + ", " + useb + R"(, acc);
    plut = plut + 2;
   }
   int32 rxc[TM] = ridx * TM + (0 ... TM);
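To make the string splicing concrete: for FPROP the B tile is laid out [TN, TK] and transposed inside the dot, while for BPROP it is [TK, TN] and used directly, which is what turns the same template into dx = dy * w^T. A standalone reproduction of just the offb line (mirroring, not reusing, the member code above):

    #include <iostream>
    #include <string>

    enum op_t { FPROP, BPROP };

    // Rebuilds the generated offb declaration for a given pass direction.
    std::string offb_line(op_t op) {
      std::string sizeb = (op == FPROP) ? "TN, TK" : "TK, TN";
      std::string bcb0  = (op == FPROP) ? ":, newaxis" : "newaxis, :";
      std::string bcb1  = (op == FPROP) ? "newaxis, :" : ":, newaxis";
      std::string ldb0  = (op == FPROP) ? "1" : "TK";
      std::string ldb1  = (op == FPROP) ? "TK" : "1";
      return "int32 offb[" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 +
             " + rkb[" + bcb1 + "]*" + ldb1 + ";";
    }

    int main() {
      std::cout << offb_line(FPROP) << "\n";
      // int32 offb[TN, TK] = ryb[:, newaxis]*1 + rkb[newaxis, :]*TK;
      std::cout << offb_line(BPROP) << "\n";
      // int32 offb[TK, TN] = ryb[newaxis, :]*TK + rkb[:, newaxis]*1;
    }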
@@ -123,20 +129,17 @@ void dot::triton_c_src(std::ostream &os) const {
   )" + c_ty_ + R"( c[TM, TN] = acc;
   )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
   int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
-  if(lockid == 0){
+  if(lockid == 0)
     @checkc *pc = c;
-  }
-  else{
+  else {
     int32 *plock = locks + ridx*nlocks + lockid - 1;
     int32 *pcount = plock + get_num_program(0)*nlocks;
     while(__atomic_cas(plock, 0, 1));
     int32 count = *pcount;
-    if(count == 0) {
+    if(count == 0)
       @checkc *pc = c;
-    }
-    else {
+    else
       @checkc *pc = c + *pc;
-    }
     *pcount = 1;
     __atomic_exch(plock, 0);
   }
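The tail of the kernel handles split reductions: when a column's blocks are spread over several segments, each program holding a partial tile takes a spin lock, and the first writer stores while later writers accumulate, distinguished by the count word. A CPU analogue of that protocol using std::atomic (a sketch only; __atomic_cas and __atomic_exch are the Triton-C intrinsics above):

    #include <atomic>
    #include <iostream>
    #include <thread>
    #include <vector>

    std::atomic<int> lock{0};
    int count = 0;     // 0: nobody has written the tile yet
    float c_tile = 0;  // stands in for *pc

    void write_tile(float partial) {
      while (lock.exchange(1)) {}        // spin, like __atomic_cas(plock, 0, 1)
      if (count == 0) c_tile = partial;  // first segment: plain store
      else            c_tile += partial; // later segments: accumulate
      count = 1;                         // *pcount = 1
      lock.store(0);                     // __atomic_exch(plock, 0)
    }

    int main() {
      std::vector<std::thread> ts;
      for (int i = 0; i < 4; ++i) ts.emplace_back(write_tile, 1.0f);
      for (auto& t : ts) t.join();
      std::cout << c_tile << "\n"; // 4: one store plus three accumulations
    }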