[dnn/blocksparse/dot] prototype version seems to pass basic test
@@ -4,7 +4,7 @@
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"
#include "triton/tools/bench.hpp"
#include "cuda.h"

@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;

@@ -98,8 +98,9 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 8192, 8192, 8192}
// {false, true, 32768, 256, 512}
// {false, true, 8192, 8192, 8192}
{false, true, 32768, 256, 256},
{false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark triton
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
cublasGemmAlgo_t fastest;
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(), &beta,
dc, shift.ldc(), &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(),
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// NumericT alpha = 1;
// NumericT beta = 0;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
//                   &alpha, da, shift.lda(),
//                   db, shift.ldb(), &beta,
//                   dc, shift.ldc(), &fastest);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
//                                                             &alpha, da, shift.lda(),
//                                                             db, shift.ldb(),
//                                                             &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// result
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
perf_t result;
result.cublas = tflops(cublas_ns);
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
delete da;
delete db;

@@ -133,8 +133,9 @@ int main() {
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
};
for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
}
}

// initialize default compute device
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
add_library(tf_blocksparse SHARED blocksparse.cpp dot.cpp conv.cpp shift.cpp batchnorm.cpp)
target_link_libraries(tf_blocksparse tensorflow_framework triton)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
${CMAKE_CURRENT_BINARY_DIR}/run.py
@@ -3,7 +3,8 @@
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/jit.h"
#include "triton/runtime/jit.h"
#include "triton/dnn/blocksparse/dot.h"

#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"

@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;

const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {

}
)";

Status XpropShape(InferenceContext* ctx)
{
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));

// C ==> K
ShapeHandle x = ctx->input(0);
int rank = ctx->Rank(x);
//printf("XpropShape: %d\n", rank);
if (rank > 0)
{
std::vector<DimensionHandle> shape;
shape.reserve(rank);
for (int i = 0; i < rank; i++)
shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));

ctx->set_output(0, ctx->MakeShape(shape));
}
else
ctx->set_output(0, ctx->UnknownShape());
ctx->set_output(1, ctx->UnknownShape());
return Status::OK();
// C ==> K
ShapeHandle x = ctx->input(0);
int rank = ctx->Rank(x);
//printf("XpropShape: %d\n", rank);
if (rank > 0)
{
std::vector<DimensionHandle> shape;
shape.reserve(rank);
for (int i = 0; i < rank; i++)
shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
ctx->set_output(0, ctx->MakeShape(shape));
}
else
ctx->set_output(0, ctx->UnknownShape());
ctx->set_output(1, ctx->UnknownShape());
return Status::OK();
}
REGISTER_OP("BlocksparseMatmul")
|
||||
.Input("x: T")
|
||||
.Input("w: T")
|
||||
.Input("lut: int64")
|
||||
.Input("lut_dx: int64")
|
||||
.Input("lut_dw: int64")
|
||||
.Input("gate: ngate * float")
|
||||
.Output("y: T")
|
||||
.Output("temp: int32")
|
||||
.Attr("T: {half, float, bfloat16}")
|
||||
.Attr("blocks: int >=0")
|
||||
.Attr("bsize: int")
|
||||
.Attr("segments: int = 0")
|
||||
.Attr("segments_dx: int = 0")
|
||||
.Attr("locks: int = 0")
|
||||
.Attr("locks_dx: int = 0")
|
||||
.Attr("axis: int = 1")
|
||||
.Attr("C: int >=0")
|
||||
.Attr("K: int >=0")
|
||||
.Attr("shared: int = 0")
|
||||
.Attr("shared_dx: int = 0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.0")
|
||||
.Attr("gated_dw: bool = false")
|
||||
.Attr("gate_grad: bool = false")
|
||||
.Attr("bench: int = 0")
|
||||
.Attr("ngate: int >= 0")
|
||||
.SetShapeFn(XpropShape)
|
||||
.Doc(R"doc(
|
||||
Multiply the matrix "a" by the blocksparse matrix "b".
|
||||
)doc");
|
||||
REGISTER_OP("TritonBlocksparseMatmul")
|
||||
.Input("x: T")
|
||||
.Input("w: T")
|
||||
.Input("lut: int64")
|
||||
.Input("lut_dx: int64")
|
||||
.Input("lut_dw: int64")
|
||||
.Input("gate: ngate * float")
|
||||
.Output("y: T")
|
||||
.Output("temp: int32")
|
||||
.Attr("T: {half, float, bfloat16}")
|
||||
.Attr("blocks: int >=0")
|
||||
.Attr("bsize: int")
|
||||
.Attr("segments: int = 0")
|
||||
.Attr("segments_dx: int = 0")
|
||||
.Attr("locks: int = 0")
|
||||
.Attr("locks_dx: int = 0")
|
||||
.Attr("axis: int = 1")
|
||||
.Attr("C: int >=0")
|
||||
.Attr("K: int >=0")
|
||||
.Attr("shared: int = 0")
|
||||
.Attr("shared_dx: int = 0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.0")
|
||||
.Attr("gated_dw: bool = false")
|
||||
.Attr("gate_grad: bool = false")
|
||||
.Attr("bench: int = 0")
|
||||
.Attr("ngate: int >= 0")
|
||||
.SetShapeFn(XpropShape)
|
||||
.Doc(R"doc(
|
||||
Multiply the matrix "a" by the blocksparse matrix "b".
|
||||
)doc");
|
||||
|
||||
|
||||
typedef struct bsmm_params
{
const int* Lut;
const float* Gate;
int* Lock;
//float4* Scratch;
int blocks;
int bsize;
int segments;
int locks;
int C;
int K;
int N;
int shared;
int pcount;
uint blk_a;
uint blk_A;
uint blk_b;
uint blk_B;
float alpha;
float beta;
CUstream stream;
const int* Lut;
const float* Gate;
int* Lock;
int blocks;
int bsize;
int segments;
int locks;
int C;
int K;
int N;
int shared;
int pcount;
uint blk_a;
uint blk_A;
uint blk_b;
uint blk_B;
float alpha;
float beta;
CUstream stream;
} bsmm_params;

class BlocksparseMatmulOp : public OpKernel {
public:
public:
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));

@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& lut = context->input(2);
// allocate c
TensorShape shape_c;
int N = 1;
int rank_a = a.dims();
for (int i = 0; i < rank_a; i++)
if (i != axis_) {
shape_c.AddDim(a.dim_size(i));
N *= a.dim_size(i);
}
else
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
// grid and block
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
blkN = 64;
gridN = (N + 63)/64;
}
// allocate locks
Tensor* locks;
TensorShape shape_l;
if (params_.locks > 0)
shape_l.AddDim(gridN * params_.locks * 2);
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// initialize default compute device
triton::runtime::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
// blocksparse matmul
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
}

private:
@@ -157,4 +185,4 @@ private:
char bench_string_[256];
};

REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
@@ -5,7 +5,6 @@
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/gemm.h"
#include "triton/dnn/conv.h"

#define EIGEN_USE_GPU

@@ -5,7 +5,7 @@
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"

#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -19,10 +19,10 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;

template<triton::dnn::shift::op_t OP>
template<triton::dnn::op_t OP>
class ShiftConvOp : public OpKernel {
public:
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::NCHW) {
context->GetAttr("shift_h", &h_shift_h_);
context->GetAttr("shift_w", &h_shift_w_);
context->GetAttr("stride_h", &stride_h_);

@@ -32,13 +32,13 @@ public:
}

void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) {
if(layout_ == triton::dnn::shift::CHWN){
if(layout_ == triton::dnn::CHWN){
C = x.dim_size(0);
H = x.dim_size(1);
W = x.dim_size(2);
B = x.dim_size(3);
}
else if(layout_ == triton::dnn::shift::NCHW){
else if(layout_ == triton::dnn::NCHW){
B = x.dim_size(0);
C = x.dim_size(1);
H = x.dim_size(2);

@@ -52,7 +52,7 @@ public:
void FillShapes(OpKernelContext* context,
int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
const Tensor& tf_a, const Tensor& tf_b) {
if(OP == triton::dnn::shift::WGRAD) {
if(OP == triton::dnn::WGRAD) {
int64_t Ha, Wa, Ba;
int64_t Hb, Wb, Bb;
ExtractShapes(tf_a, F, Ha, Wa, Ba);

@@ -68,19 +68,19 @@ public:
// shapes for a
int64_t Ca;
ExtractShapes(tf_a, Ca, H, W, B);
if(OP == triton::dnn::shift::BPROP){
if(OP == triton::dnn::BPROP){
H *= stride_h_;
W *= stride_w_;
}
// shapes for b
int64_t Cb = tf_b.dim_size(0);
F = tf_b.dim_size(1);
if(OP == triton::dnn::shift::BPROP)
if(OP == triton::dnn::BPROP)
std::swap(Cb, F);
// checks
OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
C = Ca;
if(OP == triton::dnn::shift::BPROP)
if(OP == triton::dnn::BPROP)
std::swap(C, F);
}
}

@@ -122,7 +122,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
shift.enqueue(stream, {&da, &db, &dc}, false);
shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
}

private:
@@ -132,10 +132,10 @@ private:
int stride_w_;
int R_;
int S_;
triton::dnn::shift::layout_t layout_;
triton::dnn::layout_t layout_;
};

REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::FPROP>);
REGISTER_OP("ShiftConv")
.Input("a: float16")
.Input("b: float16")

@@ -145,7 +145,7 @@ REGISTER_OP("ShiftConv")
.Attr("stride_w: int")
.Output("c: float16");

REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::BPROP>);
REGISTER_OP("ShiftConvDx")
.Input("a: float16")
.Input("b: float16")

@@ -155,7 +155,7 @@ REGISTER_OP("ShiftConvDx")
.Attr("stride_w: int")
.Output("c: float16");

REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::WGRAD>);
REGISTER_OP("ShiftConvDw")
.Input("a: float16")
.Input("b: float16")
include/triton/dnn/blocksparse/dot.h (new file, 42 lines)
@@ -0,0 +1,42 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
#include <string>

namespace triton{
namespace dnn{
namespace blocksparse{


class dot: public base {
private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
// number of flops
virtual size_t num_flops() const;
// comparison for maps
virtual bool operator<(const base& other) const;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;

public:
// constructor
dot(int32_t M, int32_t N, int32_t K);
// triton-c source
virtual void triton_c_src(std::ostream &os) const;
// clone
virtual base* clone() const;

private:
std::string ab_ty_;
std::string c_ty_;
int32_t M_;
int32_t N_;
int32_t K_;
};

}
}
}
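
For reference, a minimal host-side sketch of how this class is driven, mirroring the TensorFlow kernel added in this commit; the stream and device buffers are assumed to already exist, and the function below is illustrative rather than part of the diff:

    // Run a block-sparse product with the dot operation declared above.
    // Assumes `stream` wraps a live CUDA stream and da/db/dc/dlut wrap valid
    // device allocations, exactly as the TensorFlow op in this commit sets up.
    #include "triton/driver/stream.h"
    #include "triton/dnn/blocksparse/dot.h"

    void bsdot_example(triton::driver::stream *stream,
                       triton::driver::buffer *da, triton::driver::buffer *db,
                       triton::driver::buffer *dc, triton::driver::buffer *dlut,
                       int32_t N, int32_t K, int32_t C) {
      // the constructor takes (M, N, K); the TensorFlow op above passes (N, params_.K, params_.C)
      triton::dnn::blocksparse::dot dot(N, K, C);
      // NO_TUNING compiles with pre-validated parameters instead of autotuning
      dot.enqueue(stream, {da, db, dc, dlut}, triton::dnn::NO_TUNING);
      stream->synchronize();
    }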

@@ -13,79 +13,95 @@ typedef std::tuple<size_t, size_t> size_key_t;
static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
/* NN */
{trans_key_t(false, false), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
{size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
{size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
{size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
{size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
{size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
{size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
{size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
{size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
{size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
{size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
{size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
{size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
{size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
{size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
{size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
{size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
}},
/* NT */
{trans_key_t(false, true), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
{size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
{size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
{size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
{size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
{size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
{size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
{size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
{size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
{size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
}},
{size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
{size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
{size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
{size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
{size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
{size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
{size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
{size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
{size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
{size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
{size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
{size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
{size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
{size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
{size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
{size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
}},
/* TN */
{trans_key_t(true, false), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
{size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
{size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
{size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
{size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
{size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
{size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
{size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
{size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
{size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
}},
/* TT */
{trans_key_t(true, true), std::map<size_key_t, params_t>{
{size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
{size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
{size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
{size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
{size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
{size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
{size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
{size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
{size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
{size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
{size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
{size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
{size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
{size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
{size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
{size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
{size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
{size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
{size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
{size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
{size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
{size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
{size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
{size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
{size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
}}
};
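
A minimal, self-contained sketch of the lookup pattern this table serves (treating params_t as a flat vector and trans_key_t as an (AT, BT) pair is an assumption consistent with the keys above; the sample entry is copied from the NT block):

    #include <cstdio>
    #include <map>
    #include <tuple>
    #include <vector>

    using params_t    = std::vector<unsigned>;       // one pre-tuned configuration (assumed representation)
    using trans_key_t = std::tuple<bool, bool>;      // (AT, BT) transposition pair (assumed)
    using size_key_t  = std::tuple<size_t, size_t>;  // (TM, TN) tile shape

    int main() {
      std::map<trans_key_t, std::map<size_key_t, params_t>> table;
      table[trans_key_t(false, true)][size_key_t(128, 128)] =
          {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1};
      // fetch the pre-tuned configuration for an NT product with a 128x128 tile
      const params_t &cfg = table.at(trans_key_t(false, true)).at(size_key_t(128, 128));
      std::printf("%zu parameters\n", cfg.size());
      return 0;
    }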

@@ -72,6 +72,7 @@ public:
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
// optimize_dce.run(module);
}

void target_dependent(ir::module &module) {
@@ -1190,9 +1190,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1){
if(gep->getNumIndices() == 1)
cst = dyn_cast<ConstantInt>(gep->idx_begin());
}
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
packets[id] = builder.CreateLoad(ptr);

@@ -1202,7 +1201,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
// element-wise
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}
@@ -59,8 +59,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
// params_t params = heuristics();
params_t params = jit->get_valid(name_.c_str(), src.c_str());
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
lib/dnn/blocksparse/dot.cpp (new file, 109 lines)
@@ -0,0 +1,109 @@
#include "triton/dnn/blocksparse/dot.h"

namespace triton{
namespace dnn{
namespace blocksparse{


size_t dot::num_flops() const {

}

bool dot::operator <(const base& other) const {
auto *y = dynamic_cast<const dot*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_)
< std::tie(y->M_, y->N_, y->K_);
}

std::vector<params_t> dot::search_space() const {

}

params_t dot::heuristics() const {

}

base * dot::clone() const {
return new dot(*this);
}

dot::dot(int32_t M, int32_t N, int32_t K):
base("bsdot"), M_(M), N_(N), K_(K) {
ab_ty_ = "fp32";
c_ty_ = "fp32";
}

void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args, runtime::launch_information info) {
driver::buffer *a = args[0];
driver::buffer *b = args[1];
driver::buffer *c = args[2];
driver::buffer *lut = args[3];
int32_t lda = M_;
int32_t ldc = M_;
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
kernel->setArg(3, lda);
kernel->setArg(4, ldc);
kernel->setArg(5, lut);
int32_t TM = info.globals["TM"];
int32_t TN = info.globals["TN"];
size_t grid_0 = (M_ + TM - 1) / TM;
size_t grid_1 = (N_ + TN - 1) / TN;
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
stream->synchronize();
}

void dot::triton_c_src(std::ostream &os) const {
std::string result =

R"(
const tunable int32 TM = {64, 128};
const tunable int32 TN = {32};
const tunable int32 TK = {32};

void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
restrict read_only align(16) )" + ab_ty_ + R"( *B,
fp32* C,
int32 lda, int32 ldc,
int32* lut_base){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32 c[TM, TN] = 0;
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
int32 rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = 0 ... TN;
int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
int32 *header = lut_base + ridy * 4;
int32 offset = *(header + 0);
int32 K = *(header + 1);
int32 h2 = *(header + 2);
int32 h3 = *(header + 3);
int32 *lut = lut_base + offset*2;
for(int32 k = K; k > 0; k = k - 1){
int32 ak = *(lut + 0);
int32 bk = *(lut + 1);
fp32* pa[TM, TK] = A + offa + ak * TK * lda;
fp32* pb[TK, TN] = B + offb + bk * TK * TN;
fp32 a[TM, TK] = *pa;
fp32 b[TK, TN] = *pb;;
c = dot(a, b, c);
lut = lut + 2;
}
int32 rxc[TM] = ridx * TM + (0 ... TM);
int32 ryc[TN] = ridy * TN + (0 ... TN);
fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
*pc = c;
})";

os << result;
}

}
}
}
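
The bsdot kernel above dereferences its look-up table as a 4-int header per output block column (offset, K, plus two fields it does not use here), with the K pairs of (ak, bk) block indices for that column starting at lut_base + offset*2. A small illustrative builder for a fully dense pattern is sketched below; it is not part of this commit, and the block-index assignment for B is an assumption:

    #include <cstdint>
    #include <vector>

    // Build a LUT with one (offset, K, h2, h3) header per output block column,
    // followed by K pairs of (ak, bk) int32 entries per column, matching the
    // layout read by the bsdot kernel above. h2/h3 are left at 0 as placeholders.
    std::vector<int32_t> make_dense_lut(int32_t num_block_columns, int32_t k_blocks) {
      std::vector<int32_t> lut(num_block_columns * 4, 0);  // headers first
      for (int32_t ridy = 0; ridy < num_block_columns; ridy++) {
        lut[ridy * 4 + 0] = (int32_t)(lut.size() / 2);     // offset, counted in int32 pairs from lut_base
        lut[ridy * 4 + 1] = k_blocks;                      // K: number of (ak, bk) entries for this column
        for (int32_t k = 0; k < k_blocks; k++) {
          lut.push_back(k);                                // ak: block index along the K axis of A
          lut.push_back(ridy * k_blocks + k);              // bk: packed block index into sparse B (assumed ordering)
        }
      }
      return lut;
    }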
@@ -1,6 +1,6 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/dot.h"
#include "triton/dnn/heuristics.h"
#include <string>
@@ -101,8 +101,8 @@ void dot::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {16, 32, 64, 128, 256};
const tunable int32 TN = {16, 32, 64, 128, 256};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};
@@ -467,7 +467,7 @@ if(op_ == WGRAD){
pa = pa_base + offxa[:, newaxis];)";
}
result += R"(
@checka a = *pa;)";
a = checka ? *pa : 0;)";

/* Increment B pointers */
if(op_ == WGRAD){

@@ -488,7 +488,7 @@ if(op_ == BPROP){
pb = pb + TK;)";
}
result += R"(
@checkb b = *pb;
b = checkb ? *pb : 0;
}
int32 rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = ridy*TN + (0 ... TN);)";
@@ -37,12 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
ThreadPool pool(nthreads);
// ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
pool.enqueue(f,values);
// pool.enqueue(f,values);
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;

@@ -50,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
}
i = D - 1;
// Short sleep so that the thread pool doesn't grow too big
std::this_thread::sleep_for(std::chrono::microseconds(1));
// std::this_thread::sleep_for(std::chrono::microseconds(1));
}
}
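
With the thread pool commented out, parallel_loop_nest now visits every index tuple serially. A standalone sketch of that behaviour is below; the carry/reset step elided by the hunk is reconstructed here and is an assumption:

    #include <cstdio>
    #include <functional>
    #include <vector>

    // Serially call f on every index tuple in the Cartesian product
    // [0, ranges[0]) x ... x [0, ranges[D-1]).
    void loop_nest(const std::vector<size_t> &ranges,
                   const std::function<void(const std::vector<size_t> &)> &f) {
      size_t D = ranges.size();
      std::vector<size_t> values(D, 0);
      size_t i = D - 1;                      // start with the innermost loop
      while (true) {
        f(values);                           // execute the function on the current tuple
        // odometer-style increment: wrap the current dimension and carry outward
        while (values[i]++ == ranges[i] - 1) {
          if (i == 0)
            return;
          values[i--] = 0;
        }
        i = D - 1;
      }
    }

    int main() {
      loop_nest({2, 3}, [](const std::vector<size_t> &v) {
        std::printf("(%zu, %zu)\n", v[0], v[1]);
      });
      return 0;
    }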

@@ -211,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
// for(size_t i = 0; i < params.size(); i++)
//   std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};