[dnn/blocksparse/dot] prototype version seems to pass basic test

Philippe Tillet
2019-07-27 21:21:36 -07:00
parent 2a377bc8b1
commit 17cb2db356
18 changed files with 402 additions and 205 deletions

View File

@@ -4,7 +4,7 @@
 #include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #include "triton/tools/bench.hpp"
 #include "cuda.h"
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
   // benchmark cublas
   NumericT alpha = 1;
   NumericT beta = 0;
@@ -98,8 +98,9 @@ int main() {
   // shapes to benchmark
   std::vector<config_t> configs = {
     // {false, false, 8192, 512, 512},
-    {false, true, 8192, 8192, 8192}
-    // {false, true, 32768, 256, 512}
+    // {false, true, 8192, 8192, 8192}
+    {false, true, 32768, 256, 256},
+    {false, true, 32768, 256, 512}
     // {true, false, 8192, 512, 512},
     // {true, true, 8192, 512, 512}
   };

View File

@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
   stream->write(dc, true, 0, hc);
   stream->synchronize();
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
   // benchmark cublas
-  NumericT alpha = 1;
-  NumericT beta = 0;
-  cublasGemmAlgo_t fastest;
-  cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-             &alpha, da, shift.lda(),
-             db, shift.ldb(), &beta,
-             dc, shift.ldc(), &fastest);
-  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-                                                             &alpha, da, shift.lda(),
-                                                             db, shift.ldb(),
-                                                             &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
+  // NumericT alpha = 1;
+  // NumericT beta = 0;
+  // cublasGemmAlgo_t fastest;
+  // cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+  //            &alpha, da, shift.lda(),
+  //            db, shift.ldb(), &beta,
+  //            dc, shift.ldc(), &fastest);
+  // double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+  //                                                            &alpha, da, shift.lda(),
+  //                                                            db, shift.ldb(),
+  //                                                            &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
   // result
   auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
   perf_t result;
-  result.cublas = tflops(cublas_ns);
+  // result.cublas = tflops(cublas_ns);
   result.triton = tflops(triton_ns);
   delete da;
   delete db;
@@ -133,8 +133,9 @@ int main() {
     {128, 1024, 8, 8, 3, 3, 1024, 1, 1}
   };
   for(config_t c: resnet18){
-    for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
+    for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
       configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
+    }
   }
   // initialize default compute device

View File

@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
   include_directories("${CUDA_HOME}/include")
   link_directories(${TF_LIB})
   add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-  add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
+  add_library(tf_blocksparse SHARED blocksparse.cpp dot.cpp conv.cpp shift.cpp batchnorm.cpp)
   target_link_libraries(tf_blocksparse tensorflow_framework triton)
   configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                  ${CMAKE_CURRENT_BINARY_DIR}/run.py

View File

@@ -3,7 +3,8 @@
 #include "triton/driver/buffer.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
+#include "triton/dnn/blocksparse/dot.h"
 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 using GPUDevice = Eigen::GpuDevice;
 
-const char* src =
-R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {16, 32, 64, 128};
-const tunable int32 TK = {8};
-const tunable int32 GZ = {1};
-
-void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
-           int32 M, int32 N, int32 K,
-           int32 lda, int32 ldb, int32 ldc,
-           int32 *locks, int32 grid0, int32 grid1) {
-}
-)";
-
 Status XpropShape(InferenceContext* ctx)
 {
   int K;    TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
   int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
   // C ==> K
   ShapeHandle x = ctx->input(0);
   int rank = ctx->Rank(x);
   //printf("XpropShape: %d\n", rank);
   if (rank > 0)
   {
     std::vector<DimensionHandle> shape;
     shape.reserve(rank);
     for (int i = 0; i < rank; i++)
       shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
     ctx->set_output(0, ctx->MakeShape(shape));
   }
   else
     ctx->set_output(0, ctx->UnknownShape());
   ctx->set_output(1, ctx->UnknownShape());
   return Status::OK();
 }
 
-REGISTER_OP("BlocksparseMatmul")
+REGISTER_OP("TritonBlocksparseMatmul")
   .Input("x: T")
   .Input("w: T")
   .Input("lut: int64")
   .Input("lut_dx: int64")
   .Input("lut_dw: int64")
   .Input("gate: ngate * float")
   .Output("y: T")
   .Output("temp: int32")
   .Attr("T: {half, float, bfloat16}")
   .Attr("blocks: int >=0")
   .Attr("bsize: int")
   .Attr("segments: int = 0")
   .Attr("segments_dx: int = 0")
   .Attr("locks: int = 0")
   .Attr("locks_dx: int = 0")
   .Attr("axis: int = 1")
   .Attr("C: int >=0")
   .Attr("K: int >=0")
   .Attr("shared: int = 0")
   .Attr("shared_dx: int = 0")
   .Attr("alpha: float = 1.0")
   .Attr("beta: float = 0.0")
   .Attr("gated_dw: bool = false")
   .Attr("gate_grad: bool = false")
   .Attr("bench: int = 0")
   .Attr("ngate: int >= 0")
   .SetShapeFn(XpropShape)
   .Doc(R"doc(
 Multiply the matrix "a" by the blocksparse matrix "b".
 )doc");
 
 typedef struct bsmm_params
 {
   const int* Lut;
   const float* Gate;
   int* Lock;
-  //float4* Scratch;
   int blocks;
   int bsize;
   int segments;
   int locks;
   int C;
   int K;
   int N;
   int shared;
   int pcount;
   uint blk_a;
   uint blk_A;
   uint blk_b;
   uint blk_B;
   float alpha;
   float beta;
   CUstream stream;
 } bsmm_params;
 
 class BlocksparseMatmulOp : public OpKernel {
 public:
   explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));
@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context){
+    // get device/stream
+    GPUDevice device = context->eigen_device<GPUDevice>();
+    triton::driver::cu_stream sstream(device.stream(), false);
+    triton::driver::context* ctx = sstream.context();
+    triton::driver::stream* stream = &sstream;
+    // get inputs
+    const Tensor& a = context->input(0);
+    const Tensor& b = context->input(1);
+    const Tensor& lut = context->input(2);
+    // allocate c
+    TensorShape shape_c;
+    int N = 1;
+    int rank_a = a.dims();
+    for (int i = 0; i < rank_a; i++)
+      if (i != axis_) {
+        shape_c.AddDim(a.dim_size(i));
+        N *= a.dim_size(i);
+      }
+      else
+        shape_c.AddDim(params_.K);
+    Tensor* c = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
+    // grid and block
+    int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
+    if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
+      blkN  = 64;
+      gridN = (N + 63)/64;
+    }
+    // allocate locks
+    Tensor* locks;
+    TensorShape shape_l;
+    if (params_.locks > 0)
+      shape_l.AddDim(gridN * params_.locks * 2);
+    OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
+    // initialize default compute device
+    triton::runtime::jit jit(ctx);
+    // matrix multiplication parameters
+    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
+    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
+    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
+    // triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
+    triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
+    // blocksparse matmul
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
+    dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
   }
 
 private:
@@ -157,4 +185,4 @@ private:
   char bench_string_[256];
 };
 
-REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);

View File

@@ -5,7 +5,6 @@
 #include "triton/driver/stream.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
-#include "triton/dnn/gemm.h"
 #include "triton/dnn/conv.h"
 #define EIGEN_USE_GPU

View File

@@ -5,7 +5,7 @@
 #include "triton/driver/stream.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"

View File

@@ -19,10 +19,10 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
-template<triton::dnn::shift::op_t OP>
+template<triton::dnn::op_t OP>
 class ShiftConvOp : public OpKernel {
 public:
-  explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
+  explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::NCHW) {
     context->GetAttr("shift_h", &h_shift_h_);
     context->GetAttr("shift_w", &h_shift_w_);
     context->GetAttr("stride_h", &stride_h_);
@@ -32,13 +32,13 @@ public:
   }
 
   void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) {
-    if(layout_ == triton::dnn::shift::CHWN){
+    if(layout_ == triton::dnn::CHWN){
       C = x.dim_size(0);
       H = x.dim_size(1);
       W = x.dim_size(2);
       B = x.dim_size(3);
     }
-    else if(layout_ == triton::dnn::shift::NCHW){
+    else if(layout_ == triton::dnn::NCHW){
       B = x.dim_size(0);
       C = x.dim_size(1);
       H = x.dim_size(2);
@@ -52,7 +52,7 @@ public:
   void FillShapes(OpKernelContext* context,
                   int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
                   const Tensor& tf_a, const Tensor& tf_b) {
-    if(OP == triton::dnn::shift::WGRAD) {
+    if(OP == triton::dnn::WGRAD) {
       int64_t Ha, Wa, Ba;
       int64_t Hb, Wb, Bb;
       ExtractShapes(tf_a, F, Ha, Wa, Ba);
@@ -68,19 +68,19 @@ public:
       // shapes for a
       int64_t Ca;
       ExtractShapes(tf_a, Ca, H, W, B);
-      if(OP == triton::dnn::shift::BPROP){
+      if(OP == triton::dnn::BPROP){
        H *= stride_h_;
        W *= stride_w_;
       }
       // shapes for b
       int64_t Cb = tf_b.dim_size(0);
       F = tf_b.dim_size(1);
-      if(OP == triton::dnn::shift::BPROP)
+      if(OP == triton::dnn::BPROP)
        std::swap(Cb, F);
       // checks
       OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
       C = Ca;
-      if(OP == triton::dnn::shift::BPROP)
+      if(OP == triton::dnn::BPROP)
        std::swap(C, F);
     }
   }
@@ -122,7 +122,7 @@ public:
     triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
     triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
-    shift.enqueue(stream, {&da, &db, &dc}, false);
+    shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
   }
 
 private:
@@ -132,10 +132,10 @@ private:
   int stride_w_;
   int R_;
   int S_;
-  triton::dnn::shift::layout_t layout_;
+  triton::dnn::layout_t layout_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::FPROP>);
 REGISTER_OP("ShiftConv")
   .Input("a: float16")
   .Input("b: float16")
@@ -145,7 +145,7 @@ REGISTER_OP("ShiftConv")
   .Attr("stride_w: int")
   .Output("c: float16");
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::BPROP>);
 REGISTER_OP("ShiftConvDx")
   .Input("a: float16")
   .Input("b: float16")
@@ -155,7 +155,7 @@ REGISTER_OP("ShiftConvDx")
   .Attr("stride_w: int")
   .Output("c: float16");
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::WGRAD>);
 REGISTER_OP("ShiftConvDw")
   .Input("a: float16")
   .Input("b: float16")

View File

@@ -0,0 +1,42 @@
+#include "triton/driver/stream.h"
+#include "triton/driver/kernel.h"
+#include "triton/dnn/base.h"
+#include <string>
+
+namespace triton{
+namespace dnn{
+namespace blocksparse{
+
+class dot: public base {
+private:
+  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
+                    std::vector<driver::buffer*> args,
+                    triton::runtime::launch_information info);
+  // number of flops
+  virtual size_t num_flops() const;
+  // comparison for maps
+  virtual bool operator<(const base& other) const;
+  // default parameters
+  virtual std::vector<params_t> search_space() const;
+  virtual params_t heuristics() const;
+
+public:
+  // constructor
+  dot(int32_t M, int32_t N, int32_t K);
+  // triton-c source
+  virtual void triton_c_src(std::ostream &os) const;
+  // clone
+  virtual base* clone() const;
+
+private:
+  std::string ab_ty_;
+  std::string c_ty_;
+  int32_t M_;
+  int32_t N_;
+  int32_t K_;
+};
+
+}
+}
+}
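
The header above only declares the interface; for orientation, the intended call pattern (mirroring what BlocksparseMatmulOp::Compute() does in the TensorFlow wrapper earlier in this commit) is roughly the sketch below. It assumes the caller already owns a CUstream and device pointers for A, B, C and the look-up table; run_bsdot is a hypothetical helper, not part of the commit.

#include "cuda.h"
#include "triton/driver/stream.h"
#include "triton/driver/buffer.h"
#include "triton/dnn/blocksparse/dot.h"

// Hypothetical standalone driver (a sketch, not part of the commit): wraps an
// existing CUDA stream and device pointers, then enqueues the blocksparse dot
// exactly as the TensorFlow op above does.
void run_bsdot(CUstream custream,
               CUdeviceptr dA, CUdeviceptr dB, CUdeviceptr dC, CUdeviceptr dLut,
               int32_t N, int32_t K, int32_t C) {
  triton::driver::cu_stream stream(custream, false);   // non-owning wrapper around the caller's stream
  triton::driver::context* ctx = stream.context();
  triton::driver::cu_buffer a(ctx, dA, false);         // non-owning device buffers
  triton::driver::cu_buffer b(ctx, dB, false);
  triton::driver::cu_buffer c(ctx, dC, false);
  triton::driver::cu_buffer lut(ctx, dLut, false);
  triton::dnn::blocksparse::dot dot(N, K, C);          // same constructor arguments as in Compute()
  dot.enqueue(&stream, {&a, &b, &c, &lut}, triton::dnn::NO_TUNING);
}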

View File

@@ -13,79 +13,95 @@ typedef std::tuple<size_t, size_t> size_key_t;
 static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
   /* NN */
   {trans_key_t(false, false), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
+    {size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
+    {size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
     {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
-    {size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
-    {size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-    {size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
-    {size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
+    {size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
+    {size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
+    {size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
+    {size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
+    {size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
+    {size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
+    {size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
     {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-    {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
-    {size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
-    {size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
-    {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
+    {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
+    {size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
+    {size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
+    {size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
+    {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
+    {size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
+    {size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
+    {size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
   }},
   /* NT */
   {trans_key_t(false, true), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
-    {size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
-    {size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
-    {size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
+    {size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
+    {size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
+    {size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
+    {size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
     {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-    {size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
-    {size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
-    {size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
-    {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-    {size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
-    {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
-    {size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
-    {size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
-    {size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
-    {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
-    {size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
+    {size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
+    {size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
+    {size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
+    {size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
+    {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
+    {size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
+    {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
+    {size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
+    {size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
+    {size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
+    {size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
+    {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
+    {size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
+    {size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
+    {size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
+    {size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
   }},
   /* TN */
   {trans_key_t(true, false), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
-    {size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+    {size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
+    {size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
+    {size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
     {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
-    {size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
-    {size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+    {size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+    {size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
+    {size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
+    {size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
+    {size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
+    {size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
+    {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
+    {size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
   }},
   /* TT */
   {trans_key_t(true, true), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
-    {size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
+    {size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
+    {size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
     {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
-    {size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
-    {size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
-    {size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
-    {size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
-    {size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
-    {size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
-    {size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
-    {size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
-    {size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
+    {size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
+    {size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
+    {size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
+    {size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
+    {size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
+    {size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
+    {size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
+    {size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
+    {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
+    {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
+    {size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
+    {size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
+    {size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
+    {size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
+    {size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
+    {size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
+    {size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
   }}
 };
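
Each table above is keyed first by the (AT, BT) transposition pair and then by the (TM, TN) tile size, with the value being the params_t handed to the JIT. A lookup therefore looks roughly like the sketch below; it assumes trans_key_t is the tuple-of-bools used in the map literals and that heuristics.h is included as-is (lookup_params is a hypothetical helper, not part of the commit).

#include "triton/dnn/heuristics.h"   // the table shown above

// Sketch (not part of the commit): index the heuristics table by transposition
// and tile size, e.g. lookup_params(false, true, 64, 128) for an NT kernel with
// a 64x128 tile. Throws std::out_of_range if the tile size is not listed.
inline params_t lookup_params(bool AT, bool BT, size_t TM, size_t TN) {
  return params.at(trans_key_t(AT, BT)).at(size_key_t(TM, TN));
}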

View File

@@ -72,6 +72,7 @@ public:
   void target_independent(ir::module &module) {
     optimize_dot.run(module);
     optimize_trans.run(module);
+    // optimize_dce.run(module);
   }
   void target_dependent(ir::module &module) {

View File

@@ -1190,9 +1190,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
         Value *ptr = pointers->get_value(idx);
         ConstantInt *cst = nullptr;
         if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
-          if(gep->getNumIndices() == 1){
+          if(gep->getNumIndices() == 1)
             cst = dyn_cast<ConstantInt>(gep->idx_begin());
-          }
         ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
                                                           ptr->getType()->getPointerAddressSpace()));
         packets[id] = builder.CreateLoad(ptr);
@@ -1202,7 +1201,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
       result->for_each([&](indices_t idx){
         unsigned linear = result->get_linear_index(idx);
         unsigned id = linear / vector_size;
-        // result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
+        result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
       });
     }
     // element-wise

View File

@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
+        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
         *params_.at(i).at("nts.d0") = *tmp;
       }
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 8, 8));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 8, 8));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }

View File

@@ -59,8 +59,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
     jit->add_module(name_.c_str(), src.c_str(), best.params);
   }
   else {
-    params_t params = heuristics();
-    // params_t params = jit->get_valid(name_.c_str(), src.c_str());
+    // params_t params = heuristics();
+    params_t params = jit->get_valid(name_.c_str(), src.c_str());
     jit->add_module(name_.c_str(), src.c_str(), params);
   }
   triton::driver::kernel* kernel = jit->get_function(name_.c_str());

lib/dnn/blocksparse/dot.cpp (new file, 109 lines)
View File

@@ -0,0 +1,109 @@
+#include "triton/dnn/blocksparse/dot.h"
+
+namespace triton{
+namespace dnn{
+namespace blocksparse{
+
+size_t dot::num_flops() const {
+}
+
+bool dot::operator <(const base& other) const {
+  auto *y = dynamic_cast<const dot*>(&other);
+  if(!y)
+    return true;
+  return std::tie(M_, N_, K_)
+       < std::tie(y->M_, y->N_, y->K_);
+}
+
+std::vector<params_t> dot::search_space() const {
+}
+
+params_t dot::heuristics() const {
+}
+
+base * dot::clone() const {
+  return new dot(*this);
+}
+
+dot::dot(int32_t M, int32_t N, int32_t K):
+  base("bsdot"), M_(M), N_(N), K_(K) {
+  ab_ty_ = "fp32";
+  c_ty_ = "fp32";
+}
+
+void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
+                       std::vector<driver::buffer *> args, runtime::launch_information info) {
+  driver::buffer *a = args[0];
+  driver::buffer *b = args[1];
+  driver::buffer *c = args[2];
+  driver::buffer *lut = args[3];
+  int32_t lda = M_;
+  int32_t ldc = M_;
+  kernel->setArg(0, a);
+  kernel->setArg(1, b);
+  kernel->setArg(2, c);
+  kernel->setArg(3, lda);
+  kernel->setArg(4, ldc);
+  kernel->setArg(5, lut);
+  int32_t TM = info.globals["TM"];
+  int32_t TN = info.globals["TN"];
+  size_t grid_0 = (M_ + TM - 1) / TM;
+  size_t grid_1 = (N_ + TN - 1) / TN;
+  stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+  stream->synchronize();
+}
+
+void dot::triton_c_src(std::ostream &os) const {
+  std::string result =
+R"(
+const tunable int32 TM = {64, 128};
+const tunable int32 TN = {32};
+const tunable int32 TK = {32};
+
+void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
+           restrict read_only align(16) )" + ab_ty_ + R"( *B,
+           fp32* C,
+           int32 lda, int32 ldc,
+           int32* lut_base){
+  int32 ridx = get_range_id(0);
+  int32 ridy = get_range_id(1);
+  fp32 c[TM, TN] = 0;
+  int32 rka[TK] = 0 ... TK;
+  int32 rkb[TK] = 0 ... TK;
+  int32 rxa[TM] = ridx * TM + (0 ... TM);
+  int32 ryb[TN] = 0 ... TN;
+  int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
+  int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
+  int32 *header = lut_base + ridy * 4;
+  int32 offset = *(header + 0);
+  int32 K = *(header + 1);
+  int32 h2 = *(header + 2);
+  int32 h3 = *(header + 3);
+  int32 *lut = lut_base + offset*2;
+  for(int32 k = K; k > 0; k = k - 1){
+    int32 ak = *(lut + 0);
+    int32 bk = *(lut + 1);
+    fp32* pa[TM, TK] = A + offa + ak * TK * lda;
+    fp32* pb[TK, TN] = B + offb + bk * TK * TN;
+    fp32 a[TM, TK] = *pa;
+    fp32 b[TK, TN] = *pb;;
+    c = dot(a, b, c);
+    lut = lut + 2;
+  }
+  int32 rxc[TM] = ridx * TM + (0 ... TM);
+  int32 ryc[TN] = ridy * TN + (0 ... TN);
+  fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
+  *pc = c;
+})";
+  os << result;
+}
+
+}
+}
+}
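
The bsdot kernel reads its sparsity pattern from lut_base: for every output block-column ridy there is a four-int32 header {offset, K, h2, h3}, and offset (counted in int32 pairs from lut_base) points at K pairs (ak, bk) naming the K-block of A and the block of B to multiply; h2 and h3 are loaded but unused in this prototype. A host-side sketch of that layout follows (not part of the commit; note the TensorFlow wrapper above binds an int64 tensor for the LUT, so the element type may still be in flux in this prototype).

#include <cstdint>
#include <utility>
#include <vector>

// Build a LUT in the layout bsdot appears to expect: headers for all block-columns
// first, then the (ak, bk) pairs for each column. columns[j] lists the non-zero
// blocks feeding output block-column j.
std::vector<int32_t> build_lut(const std::vector<std::vector<std::pair<int32_t, int32_t>>>& columns) {
  size_t ncols = columns.size();
  std::vector<int32_t> lut(4 * ncols, 0);                   // {offset, K, h2, h3} per column; h2/h3 left at 0
  int32_t offset = static_cast<int32_t>(2 * ncols);         // measured in int32 pairs from lut_base
  for (size_t j = 0; j < ncols; ++j) {
    lut[4*j + 0] = offset;                                  // where this column's pairs start
    lut[4*j + 1] = static_cast<int32_t>(columns[j].size()); // loop trip count K in the kernel
    offset += static_cast<int32_t>(columns[j].size());
  }
  for (const auto& col : columns)
    for (const auto& p : col) {
      lut.push_back(p.first);                               // ak: K-block index into A
      lut.push_back(p.second);                              // bk: block index into B
    }
  return lut;
}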

View File

@@ -1,6 +1,6 @@
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #include "triton/dnn/heuristics.h"
 #include <string>
@@ -101,8 +101,8 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
 R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {16, 32, 64, 128};
+const tunable int32 TM = {16, 32, 64, 128, 256};
+const tunable int32 TN = {16, 32, 64, 128, 256};
 const tunable int32 TK = {32};
 const tunable int32 GZ = {1};

View File

@@ -467,7 +467,7 @@ if(op_ == WGRAD){
   pa = pa_base + offxa[:, newaxis];)";
 }
 result += R"(
-    @checka a = *pa;)";
+    a = checka ? *pa : 0;)";
 /* Increment B pointers */
 if(op_ == WGRAD){
@@ -488,7 +488,7 @@ if(op_ == BPROP){
     pb = pb + TK;)";
 }
 result += R"(
-    @checkb b = *pb;
+    b = checkb ? *pb : 0;
   }
   int32 rxc[TM] = ridx*TM + (0 ... TM);
   int32 ryc[TN] = ridy*TN + (0 ... TN);)";

View File

@@ -37,12 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  ThreadPool pool(nthreads);
+  // ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    pool.enqueue(f,values);
+    // pool.enqueue(f,values);
+    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
@@ -50,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
     }
     i = D - 1;
     // Short sleep so that the thread pool doesn't grow too big
-    std::this_thread::sleep_for(std::chrono::microseconds(1));
+    // std::this_thread::sleep_for(std::chrono::microseconds(1));
   }
 }
@@ -211,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       best.perf = perf;
       best.params = params;
     }
-    for(size_t i = 0; i < params.size(); i++)
-      std::cout << ((i==0)?"":", ") << params[i] << std::flush;
-    std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
+    // for(size_t i = 0; i < params.size(); i++)
+    //   std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+    // std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
   }
 };