[dnn/blocksparse/dot]: BlocksparseDx also working
@@ -45,39 +45,6 @@ Status XpropShape(InferenceContext* ctx)
 }
 
-REGISTER_OP("TritonBlocksparseMatmul")
-  .Input("x: T")
-  .Input("w: T")
-  .Input("lut: int64")
-  .Input("lut_dx: int64")
-  .Input("lut_dw: int64")
-  .Input("gate: ngate * float")
-  .Output("y: T")
-  .Output("temp: int32")
-  .Attr("T: {half, float, bfloat16}")
-  .Attr("blocks: int >=0")
-  .Attr("bsize: int")
-  .Attr("segments: int = 0")
-  .Attr("segments_dx: int = 0")
-  .Attr("locks: int = 0")
-  .Attr("locks_dx: int = 0")
-  .Attr("axis: int = 1")
-  .Attr("C: int >=0")
-  .Attr("K: int >=0")
-  .Attr("shared: int = 0")
-  .Attr("shared_dx: int = 0")
-  .Attr("alpha: float = 1.0")
-  .Attr("beta: float = 0.0")
-  .Attr("gated_dw: bool = false")
-  .Attr("gate_grad: bool = false")
-  .Attr("bench: int = 0")
-  .Attr("ngate: int >= 0")
-  .SetShapeFn(XpropShape)
-  .Doc(R"doc(
-Multiply the matrix "a" by the blocksparse matrix "b".
-)doc");
-
 typedef struct bsmm_params
 {
   const int* Lut;
@@ -101,7 +68,7 @@ typedef struct bsmm_params
   CUstream stream;
 } bsmm_params;
 
-template<typename T>
+template<triton::dnn::blocksparse::op_t OP, typename T>
 class BlocksparseMatmulOp : public OpKernel {
  public:
   explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
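The one-line change above is the heart of the commit: the transposition mode becomes a compile-time template parameter instead of being baked into a forward-only kernel. Only FPROP and BPROP appear in this diff; a minimal declaration consistent with that usage would look like the sketch below (an assumption, not the actual Triton header, which may carry more members such as a weight-gradient mode).

// Sketch only, not the actual Triton header: op_t as this commit uses it.
namespace triton { namespace dnn { namespace blocksparse {
enum op_t { FPROP, BPROP };  // FPROP: forward product; BPROP: data gradient
} } }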
@@ -167,7 +134,7 @@ public:
     triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
     triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
     // create profile
-    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks);
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
     // blocksparse matmul
     dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
   }
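Forwarding OP into the dot profile lets one Compute body serve both directions: the operand buffers (da, db, dc) are bound to (x, w, y) for FPROP and to (dy, w, dx) for BPROP, while the LUT and lock plumbing stays identical. A small compile-time dispatch sketch in the same spirit (illustrative only, assuming the enum sketched above and the usual matmul backward):

// Illustrative: a non-type template parameter selects the semantics at
// compile time, mirroring how OP is forwarded into the dot profile.
template<triton::dnn::blocksparse::op_t OP>
const char* direction() {
  return OP == triton::dnn::blocksparse::FPROP
             ? "y = x * w (forward pass)"
             : "dx = dy * w^T (data gradient)";
}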
@@ -180,5 +147,69 @@ private:
   char bench_string_[256];
 };
 
-REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<float>);
-REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<Eigen::half>);
+REGISTER_OP("TritonBlocksparseMatmul")
+  .Input("x: T")
+  .Input("w: T")
+  .Input("lut: int64")
+  .Input("lut_dx: int64")
+  .Input("lut_dw: int64")
+  .Input("gate: ngate * float")
+  .Output("y: T")
+  .Output("temp: int32")
+  .Attr("T: {half, float, bfloat16}")
+  .Attr("blocks: int >=0")
+  .Attr("bsize: int")
+  .Attr("segments: int = 0")
+  .Attr("segments_dx: int = 0")
+  .Attr("locks: int = 0")
+  .Attr("locks_dx: int = 0")
+  .Attr("axis: int = 1")
+  .Attr("C: int >=0")
+  .Attr("K: int >=0")
+  .Attr("shared: int = 0")
+  .Attr("shared_dx: int = 0")
+  .Attr("alpha: float = 1.0")
+  .Attr("beta: float = 0.0")
+  .Attr("gated_dw: bool = false")
+  .Attr("gate_grad: bool = false")
+  .Attr("bench: int = 0")
+  .Attr("ngate: int >= 0")
+  .SetShapeFn(XpropShape)
+  .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, float>);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, Eigen::half>);
+
+
+REGISTER_OP("TritonBlocksparseMatmulDX")
+  .Input("dy: T")
+  .Input("w: T")
+  .Input("lut: int64")
+  .Input("gate: ngate * float")
+  .Output("dx: T")
+  .Output("temp: int32")
+  .Attr("T: {half, float, bfloat16}")
+  .Attr("blocks: int >=0")
+  .Attr("bsize: int")
+  .Attr("segments: int = 0")
+  .Attr("locks: int = 0")
+  .Attr("axis: int = 1")
+  .Attr("C: int >=0")
+  .Attr("K: int >=0")
+  .Attr("shared: int = 0")
+  .Attr("alpha: float = 1.0")
+  .Attr("beta: float = 0.0")
+  .Attr("gated_dw: bool = false")
+  .Attr("gate_grad: bool = false")
+  .Attr("bench: int = 0")
+  .Attr("ngate: int >= 0")
+  .SetShapeFn(XpropShape)
+  .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, float>);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, Eigen::half>);
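Two details worth noting in the registrations: both ops share SetShapeFn(XpropShape), so the shape of dx is inferred by the same rule as the shape of y, and although the T attr admits half, float, and bfloat16, only float and Eigen::half kernels are registered here. The four REGISTER_KERNEL_BUILDER lines instantiate the following specializations (the aliases below are illustrative, not part of the diff):

// Illustrative aliases for the four kernel instantiations registered above.
using FpropF32 = BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, float>;
using FpropF16 = BlocksparseMatmulOp<triton::dnn::blocksparse::FPROP, Eigen::half>;
using BpropF32 = BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, float>;
using BpropF16 = BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, Eigen::half>;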