[dnn/blocksparse] FPROP (forward-propagation) test passes!
This commit is contained in:
@@ -101,6 +101,7 @@ typedef struct bsmm_params
|
||||
CUstream stream;
|
||||
} bsmm_params;
|
||||
|
||||
template<typename T>
|
||||
class BlocksparseMatmulOp : public OpKernel {
|
||||
public:
|
||||
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@@ -152,29 +153,23 @@ public:
|
||||
shape_c.AddDim(params_.K);
|
||||
Tensor* c = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
|
||||
// grid and block: start with 128-wide blocks along N, fall back to 64-wide when the condition below holds
|
||||
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
|
||||
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
|
||||
blkN = 64;
|
||||
gridN = (N + 63)/64;
|
||||
}
|
||||
// allocate locks
|
||||
int gridN = (N + 63)/64;
|
||||
Tensor* locks;
|
||||
TensorShape shape_l;
|
||||
if (params_.locks > 0)
|
||||
shape_l.AddDim(gridN * params_.locks * 2);
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
|
||||
// initialize default compute device
|
||||
triton::runtime::jit jit(ctx);
|
||||
// matrix multiplication parameters
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
||||
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
|
||||
// wrap tensorflow handles
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
|
||||
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
|
||||
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
|
||||
// create profile
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks);
|
||||
// blocksparse matmul
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
|
||||
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
|
||||
dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -185,4 +180,5 @@ private:
|
||||
char bench_string_[256];
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<Eigen::half>);
|
||||
|
Reference in New Issue
Block a user