[dnn/blocksparse/dot] prototype version seems to pass basic test

Philippe Tillet
2019-07-27 21:21:36 -07:00
parent 2a377bc8b1
commit 17cb2db356
18 changed files with 402 additions and 205 deletions

View File

@@ -4,7 +4,7 @@
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"
#include "triton/tools/bench.hpp"
#include "cuda.h"
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
@@ -98,8 +98,9 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 8192, 8192, 8192}
// {false, true, 32768, 256, 512}
// {false, true, 8192, 8192, 8192}
{false, true, 32768, 256, 256},
{false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};

View File

@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark triton
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
cublasGemmAlgo_t fastest;
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(), &beta,
dc, shift.ldc(), &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(),
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// NumericT alpha = 1;
// NumericT beta = 0;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(), &beta,
// dc, shift.ldc(), &fastest);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(),
// &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// result
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
perf_t result;
result.cublas = tflops(cublas_ns);
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
delete da;
delete db;
@@ -133,8 +133,9 @@ int main() {
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
};
for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
}
}
// initialize default compute device

View File

@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
add_library(tf_blocksparse SHARED blocksparse.cpp dot.cpp conv.cpp shift.cpp batchnorm.cpp)
target_link_libraries(tf_blocksparse tensorflow_framework triton)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
${CMAKE_CURRENT_BINARY_DIR}/run.py

View File

@@ -3,7 +3,8 @@
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/jit.h"
#include "triton/runtime/jit.h"
#include "triton/dnn/blocksparse/dot.h"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
}
)";
Status XpropShape(InferenceContext* ctx)
{
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
// C ==> K
ShapeHandle x = ctx->input(0);
int rank = ctx->Rank(x);
//printf("XpropShape: %d\n", rank);
if (rank > 0)
{
std::vector<DimensionHandle> shape;
shape.reserve(rank);
for (int i = 0; i < rank; i++)
shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
ctx->set_output(0, ctx->MakeShape(shape));
}
else
ctx->set_output(0, ctx->UnknownShape());
ctx->set_output(1, ctx->UnknownShape());
return Status::OK();
// C ==> K
ShapeHandle x = ctx->input(0);
int rank = ctx->Rank(x);
//printf("XpropShape: %d\n", rank);
if (rank > 0)
{
std::vector<DimensionHandle> shape;
shape.reserve(rank);
for (int i = 0; i < rank; i++)
shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
ctx->set_output(0, ctx->MakeShape(shape));
}
else
ctx->set_output(0, ctx->UnknownShape());
ctx->set_output(1, ctx->UnknownShape());
return Status::OK();
}
REGISTER_OP("BlocksparseMatmul")
.Input("x: T")
.Input("w: T")
.Input("lut: int64")
.Input("lut_dx: int64")
.Input("lut_dw: int64")
.Input("gate: ngate * float")
.Output("y: T")
.Output("temp: int32")
.Attr("T: {half, float, bfloat16}")
.Attr("blocks: int >=0")
.Attr("bsize: int")
.Attr("segments: int = 0")
.Attr("segments_dx: int = 0")
.Attr("locks: int = 0")
.Attr("locks_dx: int = 0")
.Attr("axis: int = 1")
.Attr("C: int >=0")
.Attr("K: int >=0")
.Attr("shared: int = 0")
.Attr("shared_dx: int = 0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.0")
.Attr("gated_dw: bool = false")
.Attr("gate_grad: bool = false")
.Attr("bench: int = 0")
.Attr("ngate: int >= 0")
.SetShapeFn(XpropShape)
.Doc(R"doc(
Multiply the matrix "a" by the blocksparse matrix "b".
)doc");
REGISTER_OP("TritonBlocksparseMatmul")
.Input("x: T")
.Input("w: T")
.Input("lut: int64")
.Input("lut_dx: int64")
.Input("lut_dw: int64")
.Input("gate: ngate * float")
.Output("y: T")
.Output("temp: int32")
.Attr("T: {half, float, bfloat16}")
.Attr("blocks: int >=0")
.Attr("bsize: int")
.Attr("segments: int = 0")
.Attr("segments_dx: int = 0")
.Attr("locks: int = 0")
.Attr("locks_dx: int = 0")
.Attr("axis: int = 1")
.Attr("C: int >=0")
.Attr("K: int >=0")
.Attr("shared: int = 0")
.Attr("shared_dx: int = 0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.0")
.Attr("gated_dw: bool = false")
.Attr("gate_grad: bool = false")
.Attr("bench: int = 0")
.Attr("ngate: int >= 0")
.SetShapeFn(XpropShape)
.Doc(R"doc(
Multiply the matrix "a" by the blocksparse matrix "b".
)doc");
typedef struct bsmm_params
{
const int* Lut;
const float* Gate;
int* Lock;
//float4* Scratch;
int blocks;
int bsize;
int segments;
int locks;
int C;
int K;
int N;
int shared;
int pcount;
uint blk_a;
uint blk_A;
uint blk_b;
uint blk_B;
float alpha;
float beta;
CUstream stream;
const int* Lut;
const float* Gate;
int* Lock;
int blocks;
int bsize;
int segments;
int locks;
int C;
int K;
int N;
int shared;
int pcount;
uint blk_a;
uint blk_A;
uint blk_b;
uint blk_B;
float alpha;
float beta;
CUstream stream;
} bsmm_params;
class BlocksparseMatmulOp : public OpKernel {
public:
public:
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));
@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& lut = context->input(2);
// allocate c
TensorShape shape_c;
int N = 1;
int rank_a = a.dims();
for (int i = 0; i < rank_a; i++)
if (i != axis_) {
shape_c.AddDim(a.dim_size(i));
N *= a.dim_size(i);
}
else
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
// grid and block
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
blkN = 64;
gridN = (N + 63)/64;
}
// allocate locks
Tensor* locks;
TensorShape shape_l;
if (params_.locks > 0)
shape_l.AddDim(gridN * params_.locks * 2);
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// initialize default compute device
triton::runtime::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
// blocksparse matmul
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
}
private:
@@ -157,4 +185,4 @@ private:
char bench_string_[256];
};
REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);

View File

@@ -5,7 +5,6 @@
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/gemm.h"
#include "triton/dnn/conv.h"
#define EIGEN_USE_GPU

View File

@@ -5,7 +5,7 @@
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"

View File

@@ -19,10 +19,10 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
template<triton::dnn::shift::op_t OP>
template<triton::dnn::op_t OP>
class ShiftConvOp : public OpKernel {
public:
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::NCHW) {
context->GetAttr("shift_h", &h_shift_h_);
context->GetAttr("shift_w", &h_shift_w_);
context->GetAttr("stride_h", &stride_h_);
@@ -32,13 +32,13 @@ public:
}
void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) {
if(layout_ == triton::dnn::shift::CHWN){
if(layout_ == triton::dnn::CHWN){
C = x.dim_size(0);
H = x.dim_size(1);
W = x.dim_size(2);
B = x.dim_size(3);
}
else if(layout_ == triton::dnn::shift::NCHW){
else if(layout_ == triton::dnn::NCHW){
B = x.dim_size(0);
C = x.dim_size(1);
H = x.dim_size(2);
@@ -52,7 +52,7 @@ public:
void FillShapes(OpKernelContext* context,
int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
const Tensor& tf_a, const Tensor& tf_b) {
if(OP == triton::dnn::shift::WGRAD) {
if(OP == triton::dnn::WGRAD) {
int64_t Ha, Wa, Ba;
int64_t Hb, Wb, Bb;
ExtractShapes(tf_a, F, Ha, Wa, Ba);
@@ -68,19 +68,19 @@ public:
// shapes for a
int64_t Ca;
ExtractShapes(tf_a, Ca, H, W, B);
if(OP == triton::dnn::shift::BPROP){
if(OP == triton::dnn::BPROP){
H *= stride_h_;
W *= stride_w_;
}
// shapes for b
int64_t Cb = tf_b.dim_size(0);
F = tf_b.dim_size(1);
if(OP == triton::dnn::shift::BPROP)
if(OP == triton::dnn::BPROP)
std::swap(Cb, F);
// checks
OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
C = Ca;
if(OP == triton::dnn::shift::BPROP)
if(OP == triton::dnn::BPROP)
std::swap(C, F);
}
}
@@ -122,7 +122,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
shift.enqueue(stream, {&da, &db, &dc}, false);
shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
}
private:
@@ -132,10 +132,10 @@ private:
int stride_w_;
int R_;
int S_;
triton::dnn::shift::layout_t layout_;
triton::dnn::layout_t layout_;
};
REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::FPROP>);
REGISTER_OP("ShiftConv")
.Input("a: float16")
.Input("b: float16")
@@ -145,7 +145,7 @@ REGISTER_OP("ShiftConv")
.Attr("stride_w: int")
.Output("c: float16");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::BPROP>);
REGISTER_OP("ShiftConvDx")
.Input("a: float16")
.Input("b: float16")
@@ -155,7 +155,7 @@ REGISTER_OP("ShiftConvDx")
.Attr("stride_w: int")
.Output("c: float16");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::WGRAD>);
REGISTER_OP("ShiftConvDw")
.Input("a: float16")
.Input("b: float16")

View File

@@ -0,0 +1,42 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
#include <string>
namespace triton{
namespace dnn{
namespace blocksparse{
class dot: public base {
private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
// number of flops
virtual size_t num_flops() const;
// comparison for maps
virtual bool operator<(const base& other) const;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;
public:
// constructor
dot(int32_t M, int32_t N, int32_t K);
// triton-c source
virtual void triton_c_src(std::ostream &os) const;
// clone
virtual base* clone() const;
private:
std::string ab_ty_;
std::string c_ty_;
int32_t M_;
int32_t N_;
int32_t K_;
};
}
}
}
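
For orientation, a minimal sketch (not part of this commit) of how the new blocksparse dot is driven from C++, mirroring the enqueue call made in BlocksparseMatmulOp::Compute in the TensorFlow wrapper above; the helper name run_bsdot and the buffer parameter names are placeholders:

#include <cstdint>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/dnn/blocksparse/dot.h"

// Hypothetical helper: drives the new op the same way the TF kernel does.
// da = dense A, db = packed block-sparse B, dc = output C, dlut = look-up table.
void run_bsdot(triton::driver::stream *stream,
               triton::driver::buffer *da, triton::driver::buffer *db,
               triton::driver::buffer *dc, triton::driver::buffer *dlut,
               int32_t N, int32_t K, int32_t C) {
  // Same constructor arguments as in the TF kernel: dot(N, params_.K, params_.C).
  triton::dnn::blocksparse::dot dot(N, K, C);
  // Argument order matches enqueue_impl(): {A, B, C, lut}; no autotuning.
  dot.enqueue(stream, {da, db, dc, dlut}, triton::dnn::NO_TUNING);
}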

View File

@@ -13,79 +13,95 @@ typedef std::tuple<size_t, size_t> size_key_t;
static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
/* NN */
{trans_key_t(false, false), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
{size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
{size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
{size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
{size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
{size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
{size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
{size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
{size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
{size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
{size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
{size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
{size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
{size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
{size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
}},
/* NT */
{trans_key_t(false, true), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
{size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
{size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
{size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
{size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
{size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
{size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
{size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
{size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
{size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
}},
{size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
{size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
{size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
{size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
{size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
{size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
{size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
{size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
{size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
{size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
{size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
{size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
{size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
{size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
{size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
}},
/* TN */
{trans_key_t(true, false), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
{size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
{size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
{size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
{size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
{size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
{size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
}},
/* TT */
{trans_key_t(true, true), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
{size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
{size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
{size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
{size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
{size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
{size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
{size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
{size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
{size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
{size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
{size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
{size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
{size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
{size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
{size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
{size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
{size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
{size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
{size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
{size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
{size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
}}
};

View File

@@ -72,6 +72,7 @@ public:
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
// optimize_dce.run(module);
}
void target_dependent(ir::module &module) {

View File

@@ -1190,9 +1190,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1){
if(gep->getNumIndices() == 1)
cst = dyn_cast<ConstantInt>(gep->idx_begin());
}
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
packets[id] = builder.CreateLoad(ptr);
@@ -1202,7 +1201,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
// element-wise

View File

@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}

View File

@@ -59,8 +59,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
// params_t params = heuristics();
params_t params = jit->get_valid(name_.c_str(), src.c_str());
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

lib/dnn/blocksparse/dot.cpp (Normal file, 109 lines)
View File

@@ -0,0 +1,109 @@
#include "triton/dnn/blocksparse/dot.h"
namespace triton{
namespace dnn{
namespace blocksparse{
size_t dot::num_flops() const {
// not implemented in this prototype
}
bool dot::operator <(const base& other) const {
auto *y = dynamic_cast<const dot*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_)
< std::tie(y->M_, y->N_, y->K_);
}
std::vector<params_t> dot::search_space() const {
// not implemented in this prototype
}
params_t dot::heuristics() const {
// not implemented in this prototype
}
base * dot::clone() const {
return new dot(*this);
}
dot::dot(int32_t M, int32_t N, int32_t K):
base("bsdot"), M_(M), N_(N), K_(K) {
ab_ty_ = "fp32";
c_ty_ = "fp32";
}
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args, runtime::launch_information info) {
driver::buffer *a = args[0];
driver::buffer *b = args[1];
driver::buffer *c = args[2];
driver::buffer *lut = args[3];
int32_t lda = M_;
int32_t ldc = M_;
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
kernel->setArg(3, lda);
kernel->setArg(4, ldc);
kernel->setArg(5, lut);
int32_t TM = info.globals["TM"];
int32_t TN = info.globals["TN"];
size_t grid_0 = (M_ + TM - 1) / TM;
size_t grid_1 = (N_ + TN - 1) / TN;
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
stream->synchronize();
}
void dot::triton_c_src(std::ostream &os) const {
std::string result =
R"(
const tunable int32 TM = {64, 128};
const tunable int32 TN = {32};
const tunable int32 TK = {32};
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
restrict read_only align(16) )" + ab_ty_ + R"( *B,
fp32* C,
int32 lda, int32 ldc,
int32* lut_base){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32 c[TM, TN] = 0;
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
int32 rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = 0 ... TN;
int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
int32 *header = lut_base + ridy * 4;
int32 offset = *(header + 0);
int32 K = *(header + 1);
int32 h2 = *(header + 2);
int32 h3 = *(header + 3);
int32 *lut = lut_base + offset*2;
for(int32 k = K; k > 0; k = k - 1){
int32 ak = *(lut + 0);
int32 bk = *(lut + 1);
fp32* pa[TM, TK] = A + offa + ak * TK * lda;
fp32* pb[TK, TN] = B + offb + bk * TK * TN;
fp32 a[TM, TK] = *pa;
fp32 b[TK, TN] = *pb;
c = dot(a, b, c);
lut = lut + 2;
}
int32 rxc[TM] = ridx * TM + (0 ... TM);
int32 ryc[TN] = ridy * TN + (0 ... TN);
fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
*pc = c;
})";
os << result;
}
}
}
}
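
The kernel reads the look-up table as a per-output-column header of four int32 words (pair offset, block count K, and two words it loads but does not use), followed by (ak, bk) pairs: ak selects a TK-wide block of A along the reduction axis, bk selects a packed TK x TN block of B. Below is a hedged host-side sketch of building a LUT with that layout from a block-sparsity mask; the layout is inferred from the loads in bsdot, and the packing order of B's blocks is an assumption, not something this commit defines:

#include <cstdint>
#include <vector>

// Builds a LUT laid out the way bsdot() above indexes it:
//   headers: 4 x int32 per output column block -> {offset (in int32 pairs), K, 0, 0}
//   entries: K x {ak, bk} pairs per column block, starting at lut_base + offset*2
// mask[kb][nb] != 0 marks a nonzero block of B; bk numbers nonzero blocks in the
// (assumed) order they are packed in the B buffer.
std::vector<int32_t> build_lut(const std::vector<std::vector<int>> &mask) {
  size_t num_kb = mask.size();      // block rows along the reduction (K) axis
  size_t num_nb = mask[0].size();   // output column blocks (TN-wide)
  // Assign packed block ids, column-major over the mask (assumption).
  std::vector<std::vector<int32_t>> block_id(num_kb, std::vector<int32_t>(num_nb, -1));
  int32_t next_id = 0;
  for (size_t nb = 0; nb < num_nb; nb++)
    for (size_t kb = 0; kb < num_kb; kb++)
      if (mask[kb][nb]) block_id[kb][nb] = next_id++;
  std::vector<int32_t> headers(4 * num_nb, 0);
  std::vector<int32_t> entries;
  int32_t pair_offset = (int32_t)(2 * num_nb);  // headers occupy 2*num_nb int32 pairs
  for (size_t nb = 0; nb < num_nb; nb++) {
    headers[4 * nb + 0] = pair_offset;          // where this column's pairs start
    int32_t K = 0;
    for (size_t kb = 0; kb < num_kb; kb++)
      if (mask[kb][nb]) {
        entries.push_back((int32_t)kb);         // ak: block index along K in A
        entries.push_back(block_id[kb][nb]);    // bk: packed block index in B
        K++;
      }
    headers[4 * nb + 1] = K;                    // number of blocks in this column
    pair_offset += K;
  }
  headers.insert(headers.end(), entries.begin(), entries.end());
  return headers;                               // upload as the int32 lut buffer
}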

View File

@@ -1,6 +1,6 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"
#include "triton/dnn/heuristics.h"
#include <string>
@@ -101,8 +101,8 @@ void dot::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {16, 32, 64, 128, 256};
const tunable int32 TN = {16, 32, 64, 128, 256};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};

View File

@@ -467,7 +467,7 @@ if(op_ == WGRAD){
pa = pa_base + offxa[:, newaxis];)";
}
result += R"(
@checka a = *pa;)";
a = checka ? *pa : 0;)";
/* Increment B pointers */
if(op_ == WGRAD){
@@ -488,7 +488,7 @@ if(op_ == BPROP){
pb = pb + TK;)";
}
result += R"(
@checkb b = *pb;
b = checkb ? *pb : 0;
}
int32 rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = ridy*TN + (0 ... TN);)";

View File

@@ -37,12 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
ThreadPool pool(nthreads);
// ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
pool.enqueue(f,values);
// pool.enqueue(f,values);
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
@@ -50,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
}
i = D - 1;
// Short sleep so that the thread pool doesn't grow too big
std::this_thread::sleep_for(std::chrono::microseconds(1));
// std::this_thread::sleep_for(std::chrono::microseconds(1));
}
}
@@ -211,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};