[dnn/blocksparse/dot] prototype version seems to pass basic test
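The new triton::dnn::blocksparse::dot operation introduced below is driven from the TensorFlow kernel's Compute() method. A minimal host-side sketch of the call sequence, assuming the A/B/C buffers and the block look-up table have already been wrapped as triton::driver::cu_buffer objects (the function and variable names here are illustrative only, not part of this change):

    #include "triton/dnn/blocksparse/dot.h"

    // Sketch only: mirrors the call sequence used by BlocksparseMatmulOp::Compute below.
    // da, db, dc hold the dense input, block-sparse weights and output; dlut holds the LUT.
    void run_bsdot(triton::driver::stream* stream,
                   triton::driver::buffer* da, triton::driver::buffer* db,
                   triton::driver::buffer* dc, triton::driver::buffer* dlut,
                   int32_t N, int32_t K, int32_t C) {
      triton::dnn::blocksparse::dot dot(N, K, C);
      // NO_TUNING skips auto-tuning and runs with default valid kernel parameters
      dot.enqueue(stream, {da, db, dc, dlut}, triton::dnn::NO_TUNING);
    }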
@@ -4,7 +4,7 @@
 #include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #include "triton/tools/bench.hpp"
 #include "cuda.h"
 
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
 stream->synchronize();
 triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
 // benchmark triton
-double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
 // benchmark cublas
 NumericT alpha = 1;
 NumericT beta = 0;
@@ -98,8 +98,9 @@ int main() {
 // shapes to benchmark
 std::vector<config_t> configs = {
 // {false, false, 8192, 512, 512},
-{false, true, 8192, 8192, 8192}
-// {false, true, 32768, 256, 512}
+// {false, true, 8192, 8192, 8192}
+{false, true, 32768, 256, 256},
+{false, true, 32768, 256, 512}
 // {true, false, 8192, 512, 512},
 // {true, true, 8192, 512, 512}
 };
@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
 stream->write(dc, true, 0, hc);
 stream->synchronize();
 // benchmark triton
-double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
 // benchmark cublas
-NumericT alpha = 1;
-NumericT beta = 0;
-cublasGemmAlgo_t fastest;
-cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-&alpha, da, shift.lda(),
-db, shift.ldb(), &beta,
-dc, shift.ldc(), &fastest);
-double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-&alpha, da, shift.lda(),
-db, shift.ldb(),
-&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
+// NumericT alpha = 1;
+// NumericT beta = 0;
+// cublasGemmAlgo_t fastest;
+// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+// &alpha, da, shift.lda(),
+// db, shift.ldb(), &beta,
+// dc, shift.ldc(), &fastest);
+// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+// &alpha, da, shift.lda(),
+// db, shift.ldb(),
+// &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
 // result
 auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
 perf_t result;
-result.cublas = tflops(cublas_ns);
+// result.cublas = tflops(cublas_ns);
 result.triton = tflops(triton_ns);
 delete da;
 delete db;
@@ -133,8 +133,9 @@ int main() {
 {128, 1024, 8, 8, 3, 3, 1024, 1, 1}
 };
 for(config_t c: resnet18){
-for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
+for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
 configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
+}
 }
 
 // initialize default compute device
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
 include_directories("${CUDA_HOME}/include")
 link_directories(${TF_LIB})
 add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
+add_library(tf_blocksparse SHARED blocksparse.cpp dot.cpp conv.cpp shift.cpp batchnorm.cpp)
 target_link_libraries(tf_blocksparse tensorflow_framework triton)
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
 ${CMAKE_CURRENT_BINARY_DIR}/run.py
@@ -3,7 +3,8 @@
 #include "triton/driver/buffer.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
+#include "triton/dnn/blocksparse/dot.h"
 
 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 using GPUDevice = Eigen::GpuDevice;
 
 
-const char* src =
-R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {16, 32, 64, 128};
-const tunable int32 TK = {8};
-const tunable int32 GZ = {1};
-
-void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
-int32 M, int32 N, int32 K,
-int32 lda, int32 ldb, int32 ldc,
-int32 *locks, int32 grid0, int32 grid1) {
-
-}
-)";
 
 Status XpropShape(InferenceContext* ctx)
 {
 int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
 int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
 
 // C ==> K
 ShapeHandle x = ctx->input(0);
 int rank = ctx->Rank(x);
 //printf("XpropShape: %d\n", rank);
 if (rank > 0)
 {
 std::vector<DimensionHandle> shape;
 shape.reserve(rank);
 for (int i = 0; i < rank; i++)
 shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
 ctx->set_output(0, ctx->MakeShape(shape));
 }
 else
 ctx->set_output(0, ctx->UnknownShape());
 ctx->set_output(1, ctx->UnknownShape());
 return Status::OK();
 }
 
 
-REGISTER_OP("BlocksparseMatmul")
+REGISTER_OP("TritonBlocksparseMatmul")
 .Input("x: T")
 .Input("w: T")
 .Input("lut: int64")
 .Input("lut_dx: int64")
 .Input("lut_dw: int64")
 .Input("gate: ngate * float")
 .Output("y: T")
 .Output("temp: int32")
 .Attr("T: {half, float, bfloat16}")
 .Attr("blocks: int >=0")
 .Attr("bsize: int")
 .Attr("segments: int = 0")
 .Attr("segments_dx: int = 0")
 .Attr("locks: int = 0")
 .Attr("locks_dx: int = 0")
 .Attr("axis: int = 1")
 .Attr("C: int >=0")
 .Attr("K: int >=0")
 .Attr("shared: int = 0")
 .Attr("shared_dx: int = 0")
 .Attr("alpha: float = 1.0")
 .Attr("beta: float = 0.0")
 .Attr("gated_dw: bool = false")
 .Attr("gate_grad: bool = false")
 .Attr("bench: int = 0")
 .Attr("ngate: int >= 0")
 .SetShapeFn(XpropShape)
 .Doc(R"doc(
 Multiply the matrix "a" by the blocksparse matrix "b".
 )doc");
 
 
 typedef struct bsmm_params
 {
 const int* Lut;
 const float* Gate;
 int* Lock;
-//float4* Scratch;
 int blocks;
 int bsize;
 int segments;
 int locks;
 int C;
 int K;
 int N;
 int shared;
 int pcount;
 uint blk_a;
 uint blk_A;
 uint blk_b;
 uint blk_B;
 float alpha;
 float beta;
 CUstream stream;
 } bsmm_params;
 
 class BlocksparseMatmulOp : public OpKernel {
 public:
 explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
 OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
 OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));
@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
 }
 
 void Compute(OpKernelContext* context){
+// get device/stream
+GPUDevice device = context->eigen_device<GPUDevice>();
+triton::driver::cu_stream sstream(device.stream(), false);
+triton::driver::context* ctx = sstream.context();
+triton::driver::stream* stream = &sstream;
+// get inputs
+const Tensor& a = context->input(0);
+const Tensor& b = context->input(1);
+const Tensor& lut = context->input(2);
+// allocate c
+TensorShape shape_c;
+int N = 1;
+int rank_a = a.dims();
+for (int i = 0; i < rank_a; i++)
+if (i != axis_) {
+shape_c.AddDim(a.dim_size(i));
+N *= a.dim_size(i);
+}
+else
+shape_c.AddDim(params_.K);
+Tensor* c = nullptr;
+OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
+// grid and block
+int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
+if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
+blkN = 64;
+gridN = (N + 63)/64;
+}
+// allocate locks
+Tensor* locks;
+TensorShape shape_l;
+if (params_.locks > 0)
+shape_l.AddDim(gridN * params_.locks * 2);
+OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
+// initialize default compute device
+triton::runtime::jit jit(ctx);
+// matrix multiplication parameters
+triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
+triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
+triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
+// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
+triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
+// blocksparse matmul
+triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
+dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
 }
 
 private:
@@ -157,4 +185,4 @@ private:
 char bench_string_[256];
 };
 
-REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
@@ -5,7 +5,6 @@
 #include "triton/driver/stream.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
-#include "triton/dnn/gemm.h"
 #include "triton/dnn/conv.h"
 
 #define EIGEN_USE_GPU
@@ -5,7 +5,7 @@
 #include "triton/driver/stream.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 
 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -19,10 +19,10 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
-template<triton::dnn::shift::op_t OP>
+template<triton::dnn::op_t OP>
 class ShiftConvOp : public OpKernel {
 public:
-explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
+explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::NCHW) {
 context->GetAttr("shift_h", &h_shift_h_);
 context->GetAttr("shift_w", &h_shift_w_);
 context->GetAttr("stride_h", &stride_h_);
@@ -32,13 +32,13 @@ public:
 }
 
 void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) {
-if(layout_ == triton::dnn::shift::CHWN){
+if(layout_ == triton::dnn::CHWN){
 C = x.dim_size(0);
 H = x.dim_size(1);
 W = x.dim_size(2);
 B = x.dim_size(3);
 }
-else if(layout_ == triton::dnn::shift::NCHW){
+else if(layout_ == triton::dnn::NCHW){
 B = x.dim_size(0);
 C = x.dim_size(1);
 H = x.dim_size(2);
@@ -52,7 +52,7 @@ public:
 void FillShapes(OpKernelContext* context,
 int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
 const Tensor& tf_a, const Tensor& tf_b) {
-if(OP == triton::dnn::shift::WGRAD) {
+if(OP == triton::dnn::WGRAD) {
 int64_t Ha, Wa, Ba;
 int64_t Hb, Wb, Bb;
 ExtractShapes(tf_a, F, Ha, Wa, Ba);
@@ -68,19 +68,19 @@ public:
 // shapes for a
 int64_t Ca;
 ExtractShapes(tf_a, Ca, H, W, B);
-if(OP == triton::dnn::shift::BPROP){
+if(OP == triton::dnn::BPROP){
 H *= stride_h_;
 W *= stride_w_;
 }
 // shapes for b
 int64_t Cb = tf_b.dim_size(0);
 F = tf_b.dim_size(1);
-if(OP == triton::dnn::shift::BPROP)
+if(OP == triton::dnn::BPROP)
 std::swap(Cb, F);
 // checks
 OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
 C = Ca;
-if(OP == triton::dnn::shift::BPROP)
+if(OP == triton::dnn::BPROP)
 std::swap(C, F);
 }
 }
@@ -122,7 +122,7 @@ public:
 triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
 triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
 triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
-shift.enqueue(stream, {&da, &db, &dc}, false);
+shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
 }
 
 private:
@@ -132,10 +132,10 @@ private:
 int stride_w_;
 int R_;
 int S_;
-triton::dnn::shift::layout_t layout_;
+triton::dnn::layout_t layout_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::FPROP>);
 REGISTER_OP("ShiftConv")
 .Input("a: float16")
 .Input("b: float16")
@@ -145,7 +145,7 @@ REGISTER_OP("ShiftConv")
 .Attr("stride_w: int")
 .Output("c: float16");
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::BPROP>);
 REGISTER_OP("ShiftConvDx")
 .Input("a: float16")
 .Input("b: float16")
@@ -155,7 +155,7 @@ REGISTER_OP("ShiftConvDx")
 .Attr("stride_w: int")
 .Output("c: float16");
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::WGRAD>);
 REGISTER_OP("ShiftConvDw")
 .Input("a: float16")
 .Input("b: float16")
include/triton/dnn/blocksparse/dot.h (new file, 42 lines)
@@ -0,0 +1,42 @@
+#include "triton/driver/stream.h"
+#include "triton/driver/kernel.h"
+#include "triton/dnn/base.h"
+#include <string>
+
+namespace triton{
+namespace dnn{
+namespace blocksparse{
+
+
+class dot: public base {
+private:
+void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
+std::vector<driver::buffer*> args,
+triton::runtime::launch_information info);
+// number of flops
+virtual size_t num_flops() const;
+// comparison for maps
+virtual bool operator<(const base& other) const;
+// default parameters
+virtual std::vector<params_t> search_space() const;
+virtual params_t heuristics() const;
+
+public:
+// constructor
+dot(int32_t M, int32_t N, int32_t K);
+// triton-c source
+virtual void triton_c_src(std::ostream &os) const;
+// clone
+virtual base* clone() const;
+
+private:
+std::string ab_ty_;
+std::string c_ty_;
+int32_t M_;
+int32_t N_;
+int32_t K_;
+};
+
+}
+}
+}
@@ -13,79 +13,95 @@ typedef std::tuple<size_t, size_t> size_key_t;
 static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
 /* NN */
 {trans_key_t(false, false), std::map<size_key_t, params_t>{
-{size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-{size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
+{size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
+{size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
 {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
-{size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
-{size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-{size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-{size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-{size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
-{size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
+{size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
+{size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
+{size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
+{size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
+{size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
+{size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
+{size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
 {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
-{size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
-{size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
-{size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
-{size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
-{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
+{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
+{size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
+{size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
+{size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
+{size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
+{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
+{size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
+{size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
+{size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
 }},
 /* NT */
 {trans_key_t(false, true), std::map<size_key_t, params_t>{
-{size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
-{size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
-{size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
-{size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
+{size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
+{size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
+{size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
+{size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
 {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-{size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
-{size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
-{size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
-{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-{size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
-{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
-{size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
-{size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
-{size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
-{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
-{size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
+{size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
+{size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
+{size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
+{size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
+{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
+{size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
+{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
+{size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
+{size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
+{size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
+{size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
+{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
+{size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
+{size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
+{size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
+{size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
 }},
 /* TN */
 {trans_key_t(true, false), std::map<size_key_t, params_t>{
-{size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
-{size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-{size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-{size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
+{size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+{size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
+{size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
+{size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
 {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-{size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
-{size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
-{size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
-{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
+{size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+{size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+{size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
+{size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
+{size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
+{size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+{size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
+{size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
+{size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
+{size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
+{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
+{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
+{size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
+{size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
 }},
 /* TT */
 {trans_key_t(true, true), std::map<size_key_t, params_t>{
-{size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
-{size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
+{size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
+{size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
 {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
-{size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
-{size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
-{size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
-{size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
-{size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
-{size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
-{size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
-{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
-{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
-{size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
-{size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
-{size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
-{size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
+{size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
+{size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
+{size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
+{size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
+{size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
+{size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
+{size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
+{size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
+{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
+{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
+{size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
+{size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
+{size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
+{size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
+{size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
+{size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
+{size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
 }}
 };
 
@@ -72,6 +72,7 @@ public:
 void target_independent(ir::module &module) {
 optimize_dot.run(module);
 optimize_trans.run(module);
+// optimize_dce.run(module);
 }
 
 void target_dependent(ir::module &module) {
@@ -1190,9 +1190,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 Value *ptr = pointers->get_value(idx);
 ConstantInt *cst = nullptr;
 if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
-if(gep->getNumIndices() == 1){
+if(gep->getNumIndices() == 1)
 cst = dyn_cast<ConstantInt>(gep->idx_begin());
-}
 ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
 ptr->getType()->getPointerAddressSpace()));
 packets[id] = builder.CreateLoad(ptr);
@@ -1202,7 +1201,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 result->for_each([&](indices_t idx){
 unsigned linear = result->get_linear_index(idx);
 unsigned id = linear / vector_size;
-// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
+result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
 });
 }
 // element-wise
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
 size_t addr_space = ptr_ty->get_pointer_address_space();
 if(addr_space < 4){
 ir::type *ty = mod.get_builder().get_int32_ty();
-std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
+std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
 *params_.at(i).at("nts.d0") = *tmp;
 }
 }
 if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
 ir::type *ty = mod.get_builder().get_int32_ty();
-std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 8, 8));
-std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 8, 8));
+std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
+std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
 *params_.at(i).at("nts.d0") = *tmp1;
 *params_.at(i).at("nts.d1") = *tmp2;
 }
@@ -59,8 +59,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
 jit->add_module(name_.c_str(), src.c_str(), best.params);
 }
 else {
-params_t params = heuristics();
-// params_t params = jit->get_valid(name_.c_str(), src.c_str());
+// params_t params = heuristics();
+params_t params = jit->get_valid(name_.c_str(), src.c_str());
 jit->add_module(name_.c_str(), src.c_str(), params);
 }
 triton::driver::kernel* kernel = jit->get_function(name_.c_str());
lib/dnn/blocksparse/dot.cpp (new file, 109 lines)
@@ -0,0 +1,109 @@
+#include "triton/dnn/blocksparse/dot.h"
+
+namespace triton{
+namespace dnn{
+namespace blocksparse{
+
+
+size_t dot::num_flops() const {
+
+}
+
+bool dot::operator <(const base& other) const {
+auto *y = dynamic_cast<const dot*>(&other);
+if(!y)
+return true;
+return std::tie(M_, N_, K_)
+< std::tie(y->M_, y->N_, y->K_);
+}
+
+std::vector<params_t> dot::search_space() const {
+
+}
+
+params_t dot::heuristics() const {
+
+}
+
+base * dot::clone() const {
+return new dot(*this);
+}
+
+dot::dot(int32_t M, int32_t N, int32_t K):
+base("bsdot"), M_(M), N_(N), K_(K) {
+ab_ty_ = "fp32";
+c_ty_ = "fp32";
+}
+
+void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
+std::vector<driver::buffer *> args, runtime::launch_information info) {
+driver::buffer *a = args[0];
+driver::buffer *b = args[1];
+driver::buffer *c = args[2];
+driver::buffer *lut = args[3];
+int32_t lda = M_;
+int32_t ldc = M_;
+kernel->setArg(0, a);
+kernel->setArg(1, b);
+kernel->setArg(2, c);
+kernel->setArg(3, lda);
+kernel->setArg(4, ldc);
+kernel->setArg(5, lut);
+int32_t TM = info.globals["TM"];
+int32_t TN = info.globals["TN"];
+size_t grid_0 = (M_ + TM - 1) / TM;
+size_t grid_1 = (N_ + TN - 1) / TN;
+stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+stream->synchronize();
+}
+
+void dot::triton_c_src(std::ostream &os) const {
+std::string result =
+
+R"(
+const tunable int32 TM = {64, 128};
+const tunable int32 TN = {32};
+const tunable int32 TK = {32};
+
+void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
+restrict read_only align(16) )" + ab_ty_ + R"( *B,
+fp32* C,
+int32 lda, int32 ldc,
+int32* lut_base){
+int32 ridx = get_range_id(0);
+int32 ridy = get_range_id(1);
+fp32 c[TM, TN] = 0;
+int32 rka[TK] = 0 ... TK;
+int32 rkb[TK] = 0 ... TK;
+int32 rxa[TM] = ridx * TM + (0 ... TM);
+int32 ryb[TN] = 0 ... TN;
+int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
+int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
+int32 *header = lut_base + ridy * 4;
+int32 offset = *(header + 0);
+int32 K = *(header + 1);
+int32 h2 = *(header + 2);
+int32 h3 = *(header + 3);
+int32 *lut = lut_base + offset*2;
+for(int32 k = K; k > 0; k = k - 1){
+int32 ak = *(lut + 0);
+int32 bk = *(lut + 1);
+fp32* pa[TM, TK] = A + offa + ak * TK * lda;
+fp32* pb[TK, TN] = B + offb + bk * TK * TN;
+fp32 a[TM, TK] = *pa;
+fp32 b[TK, TN] = *pb;;
+c = dot(a, b, c);
+lut = lut + 2;
+}
+int32 rxc[TM] = ridx * TM + (0 ... TM);
+int32 ryc[TN] = ridy * TN + (0 ... TN);
+fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
+*pc = c;
+})";
+
+os << result;
+}
+
+}
+}
+}
@@ -1,6 +1,6 @@
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #include "triton/dnn/heuristics.h"
 #include <string>
 
|
|||||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||||
std::string res =
|
std::string res =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {16, 32, 64, 128};
|
const tunable int32 TM = {16, 32, 64, 128, 256};
|
||||||
const tunable int32 TN = {16, 32, 64, 128};
|
const tunable int32 TN = {16, 32, 64, 128, 256};
|
||||||
const tunable int32 TK = {32};
|
const tunable int32 TK = {32};
|
||||||
const tunable int32 GZ = {1};
|
const tunable int32 GZ = {1};
|
||||||
|
|
@@ -467,7 +467,7 @@ if(op_ == WGRAD){
 pa = pa_base + offxa[:, newaxis];)";
 }
 result += R"(
-@checka a = *pa;)";
+a = checka ? *pa : 0;)";
 
 /* Increment B pointers */
 if(op_ == WGRAD){
|
|||||||
pb = pb + TK;)";
|
pb = pb + TK;)";
|
||||||
}
|
}
|
||||||
result += R"(
|
result += R"(
|
||||||
@checkb b = *pb;
|
b = checkb ? *pb : 0;
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||||
int32 ryc[TN] = ridy*TN + (0 ... TN);)";
|
int32 ryc[TN] = ridy*TN + (0 ... TN);)";
|
||||||
|
@@ -37,12 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
 size_t D = ranges.size();
 std::vector<size_t> values(D, 0);
 // thread pools
-ThreadPool pool(nthreads);
+// ThreadPool pool(nthreads);
 // Start with innermost loop
 size_t i = D - 1;
 while(true){
 // Execute function
-pool.enqueue(f,values);
+// pool.enqueue(f,values);
+f(values);
 while(values[i]++ == ranges[i] - 1){
 if(i == 0)
 return;
@@ -50,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
 }
 i = D - 1;
 // Short sleep so that the thread pool doesn't grow too big
-std::this_thread::sleep_for(std::chrono::microseconds(1));
+// std::this_thread::sleep_for(std::chrono::microseconds(1));
 }
 }
 
@@ -211,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
 best.perf = perf;
 best.params = params;
 }
-for(size_t i = 0; i < params.size(); i++)
-std::cout << ((i==0)?"":", ") << params[i] << std::flush;
-std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
+// for(size_t i = 0; i < params.size(); i++)
+// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
 }
 };
 