[codegen] added leading dimension padding for transposition in shared memory
This commit is contained in:
Philippe Tillet
2019-05-04 20:15:34 -04:00
parent 4813bb007c
commit f80441017c
7 changed files with 314 additions and 83 deletions

View File

@@ -14,6 +14,17 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
}
}
// Runtime front-end for the reference GEMM: forwards to the
// simple_gemm<T, AT, BT> instantiation whose compile-time flags match the
// values of the runtime flags AT and BT. All remaining arguments are passed
// through unchanged.
template<class T>
void simple_gemm(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
  if(AT) {
    // first flag set: pick on the second flag
    if(BT)
      simple_gemm<T, true, true>(c, a, b, M, N, K);
    else
      simple_gemm<T, true, false>(c, a, b, M, N, K);
    return;
  }
  // first flag cleared: pick on the second flag
  if(BT)
    simple_gemm<T, false, true>(c, a, b, M, N, K);
  else
    simple_gemm<T, false, false>(c, a, b, M, N, K);
}
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;

View File

@@ -5,63 +5,104 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
a = *pa;
b = *pb;
std::string triton_source(bool AT, bool BT) {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
if(BT){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa;
fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
}
)";
return res;
}
)";
int main() {
bool AT = false;
bool BT = true;
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
@@ -128,16 +169,16 @@ int main() {
// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
std::string src = triton_source(AT, BT);
// jit.autotune("matmul",src.c_str(), benchmark);
jit.add_module("matmul", src.c_str(), {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1});
// jit.add_module("matmul", src.c_str(), {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

View File

@@ -88,6 +88,85 @@ void conv(read_only restrict fp32 *a,
@checkc *pc = C;
})";
/* Fills the two look-up tables consumed by the convolution kernel.
 *
 * Layout written into `res` (with F = T*R*S and Nlut = TK rounded up to a
 * multiple of F):
 *   res[0 .. Nlut)        increments: ((i + TK) % Nlut) - i
 *   res[Nlut .. 2*Nlut)   pointer deltas between reduction index i and i + TK,
 *                         built from stride_c / stride_d / stride_h / stride_w
 *                         (one table per upsampling phase; upsampling is
 *                         currently fixed to 1 in every dimension)
 *
 * Layout written into `masks`:
 *   masks[0 .. Nlut)      zeros
 *   masks[Nlut + Nlut*(pw + (2*pad_w+1)*(ph + (2*pad_h+1)*pd)) + i]
 *                         bit j is set when filter tap (i + j) % F, shifted by
 *                         (pd, ph, pw), lies inside the padded filter window
 *
 * `res` and `masks` are grown when they are too small to hold the tables;
 * vectors that are already large enough keep their size. If TK or T*R*S is
 * not positive, or any padding is negative, nothing is written.
 */
void build_conv_lut(int TK,
                    int stride_d, int stride_h, int stride_w, int stride_c,
                    int pad_d, int pad_h, int pad_w,
                    int T, int R, int S,
                    std::vector<int>& res, std::vector<int>& masks) {
  /* convolution parameters */
  const int F = T * R * S;
  // Without a positive filter volume / reduction tile there is nothing to
  // tabulate (and F == 0 would divide by zero below); negative paddings would
  // give negative table counts.
  if(F <= 0 || TK <= 0 || pad_d < 0 || pad_h < 0 || pad_w < 0)
    return;
  const int Nlut = (TK + F - 1) / F * F;
  const int upsample_w = 1;
  const int upsample_h = 1;
  const int upsample_d = 1;

  /* table extents */
  const size_t lut = static_cast<size_t>(Nlut);
  const size_t Ds1 = upsample_w, Ds2 = upsample_h, Ds3 = upsample_d;
  const size_t Ms1 = 2*static_cast<size_t>(pad_w) + 1;
  const size_t Ms2 = 2*static_cast<size_t>(pad_h) + 1;
  const size_t Ms3 = 2*static_cast<size_t>(pad_d) + 1;
  // make sure every write below stays inside the vectors
  const size_t res_needed = lut + lut*Ds1*Ds2*Ds3;
  const size_t masks_needed = lut + lut*Ms1*Ms2*Ms3;
  if(res.size() < res_needed)
    res.resize(res_needed);
  if(masks.size() < masks_needed)
    masks.resize(masks_needed);

  /* unpack index wrt filters */
  auto unpack = [&](int32_t trs){
    int32_t tr = trs / S;
    int32_t s = trs - tr*S;
    int32_t t = tr / R;
    int32_t r = tr - t*R;
    return std::make_tuple(t, r, s);
  };

  /* increments (signed arithmetic: the result is negative on wrap-around) */
  for(int i = 0; i < Nlut; ++i)
    res[i] = ((i + TK) % Nlut) - i;

  /* deltas */
  for(int pd = 0; pd < upsample_d; ++pd)
  for(int ph = 0; ph < upsample_h; ++ph)
  for(int pw = 0; pw < upsample_w; ++pw){
    const size_t phase = (static_cast<size_t>(pd)*Ds2 + ph)*Ds1 + pw;
    int* deltas_ptr = &res[lut + phase*lut];
    // cumulative increments
    for(int i = 0; i < Nlut; ++i){
      const int32_t ctrs = i;
      const int32_t c = ctrs / F;
      int32_t t, r, s;
      std::tie(t, r, s) = unpack(ctrs % F);
      // next indices
      const int32_t nextctrs = ctrs + TK;
      const int32_t nextc = nextctrs / F;
      int32_t nextt, nextr, nexts;
      std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
      // diffs
      const int32_t cdiff = nextc - c;
      const int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
      const int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
      const int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
      // delta pointers
      deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
    }
  }

  /* Masks */
  for(int pd = 0; pd <= 2*pad_d; ++pd)
  for(int ph = 0; ph <= 2*pad_h; ++ph)
  for(int pw = 0; pw <= 2*pad_w; ++pw){
    const size_t phase = (static_cast<size_t>(pd)*Ms2 + ph)*Ms1 + pw;
    int* masks_ptr = &masks[lut + phase*lut];
    for(int i = 0; i < Nlut; ++i){
      int32_t t, r, s;
      // accumulate in an unsigned word so that setting bit 31 is well defined;
      // taps beyond bit 31 cannot be represented in a 32-bit mask
      uint32_t bits = 0x0;
      for(int j = 0; j < TK && j < 32; ++j){
        std::tie(t, r, s) = unpack((i + j) % F);
        const bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
        const bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
        const bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
        if(in_bounds_d && in_bounds_h && in_bounds_w)
          bits |= (static_cast<uint32_t>(1) << j);
      }
      masks_ptr[i] = static_cast<int>(bits);
    }
  }
  // the first Nlut entries of the mask table are always zero
  for(int i = 0; i < Nlut; ++i)
    masks[i] = 0x0;
}
torch::Tensor conv_forward(
const torch::Tensor data,
const torch::Tensor weight) {
@@ -95,37 +174,118 @@ torch::Tensor conv_forward(
CHECK_INPUT(data);
CHECK_INPUT(weight);
// Unpack data shapes
const auto B = data.size(0);
const auto Ci = data.size(1);
const auto H = data.size(2);
const auto W = data.size(3);
const int32_t B = data.size(0);
const int32_t Ci = data.size(1);
const int32_t H = data.size(2);
const int32_t W = data.size(3);
// Unpack weight shapes
const auto Cf = weight.size(0);
const auto R = weight.size(1);
const auto S = weight.size(2);
const auto K = weight.size(3);
const int32_t Cf = weight.size(0);
const int32_t T = 1;
const int32_t R = weight.size(1);
const int32_t S = weight.size(2);
const int32_t NF = weight.size(3);
// Conv parameters
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_h = 1, stride_w = 1;
// Output shapes
int32_t P = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
int32_t Q = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
// Allocate output
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
torch::Tensor output = torch::empty({B, K, H, W}, torch::kFloat);
torch::Tensor output = torch::empty({B, NF, P, Q}, torch::kFloat).cuda();
// Wrap CUDA handles
triton::driver::cu_stream sstream(at::cuda::getCurrentCUDAStream(), false);
c10::DeviceIndex device = output.storage().device().index();
triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
triton::driver::stream* stream = &sstream;
triton::driver::context* ctx = stream->context();
triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
triton::driver::cu_buffer a(ctx, (CUdeviceptr)output.storage().data(), false);
// Create JIT
triton::jit jit(ctx);
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
8, 1, 8,
4
};
jit.add_module("conv", src, params);
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned TK = jit.get_int("TK");
// initialize constant memory
int FS = T*R*S;
int nlut = (TK + FS - 1) / FS * FS;
std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
// memory stride for images
int32_t stride_i_w = 1;
int32_t stride_i_h = W*stride_i_w;
int32_t stride_i_d = H*stride_i_h;
int32_t stride_i_c = 1*stride_i_d;
int32_t stride_i_n = Ci*stride_i_c;
// memory stride for activations
int32_t stride_o_q = 1;
int32_t stride_o_p = Q*stride_o_q;
int32_t stride_o_m = P*stride_o_p;
int32_t stride_o_k = 1*stride_o_m;
int32_t stride_o_n = NF*stride_o_k;
build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
// equivalent matmul dimensions
int32_t M = B*P*Q;
int32_t N = NF;
int32_t K = Ci*R*S;
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set arguments
kernel->setArg(0, *d.cu());
kernel->setArg(1, *w.cu());
kernel->setArg(2, *a.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, B);
kernel->setArg(7, H);
kernel->setArg(8, W);
kernel->setArg(9, B);
kernel->setArg(10, NF);
kernel->setArg(11, P);
kernel->setArg(12, Q);
kernel->setArg(13, Ci);
kernel->setArg(14, R);
kernel->setArg(15, S);
kernel->setArg(16, stride_i_n);
kernel->setArg(17, stride_i_c);
kernel->setArg(18, stride_i_h);
kernel->setArg(19, stride_i_w);
kernel->setArg(20, stride_o_n);
kernel->setArg(21, stride_o_k);
kernel->setArg(22, stride_o_p);
kernel->setArg(23, stride_o_q);
kernel->setArg(24, pad_h);
kernel->setArg(25, pad_w);
kernel->setArg(26, bound);
// // dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
return output;
}

View File

@@ -26,6 +26,7 @@ public:
// utilities
unsigned get_num_bytes(ir::value *x);
bool is_ld_padded(ir::value* x);
// accessors
unsigned get_offset(ir::value *x) const { return offsets_.at(x); }

View File

@@ -58,7 +58,6 @@ public:
target_(target) { }
void target_independent(ir::module &module) {
// ir::print(module, std::cout);
optimize_dot.run(module);
optimize_trans.run(module);
}

View File

@@ -525,10 +525,12 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
for(ir::value *op: user->ops())
create_tile(op, builder, references, seen, sh_mem_ptr);
LLVMContext &ctx = builder.getContext();
const auto& shapes = v->get_type()->get_tile_shapes();
std::vector<unsigned> shapes2;
for(ir::constant_int* shape: shapes)
shapes2.push_back(shape->get_value());
const auto& cshapes = v->get_type()->get_tile_shapes();
std::vector<unsigned> shapes;
for(ir::constant_int* shape: cshapes)
shapes.push_back(shape->get_value());
if(alloc_->is_ld_padded(v))
shapes[0] += 4;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
// create shared tile
if(buffer_info_->is_shared(v)){
@@ -550,13 +552,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
tmap_.insert({phi, new shared_tile(ty, shapes2, ptr, builder, offset)});
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
ir::basic_block* inc_block = phi->get_incoming_block(i);
ir::value* inc_value = phi->get_incoming_value(i);
ir::instruction* terminator = inc_block->get_inst_list().back();
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
}
}
else {
@@ -564,16 +566,16 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
}
}
}
// create distributed tile
else {
const auto &shapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() > 1){
const auto &cshapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(cshapes.size());
for(size_t d = 0; d < cshapes.size(); d++){
if(cshapes[d]->get_value() > 1){
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
axes[d] = axes_.at(x);
}
@@ -583,7 +585,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
}
}
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){

View File

@@ -10,8 +10,24 @@
namespace triton{
namespace codegen{
// Returns true when x is a trans_inst, or when x is a phi_node for which at
// least one incoming value is itself reported as padded by this function.
// Every other value yields false.
bool shmem_allocation::is_ld_padded(ir::value *x) {
  ir::phi_node* as_phi = dynamic_cast<ir::phi_node*>(x);
  if(as_phi == nullptr)
    return dynamic_cast<ir::trans_inst*>(x) != nullptr;
  // phi node: inspect every incoming value and combine the answers
  bool padded = false;
  unsigned idx = 0;
  while(idx < as_phi->get_num_incoming()) {
    if(is_ld_padded(as_phi->get_incoming_value(idx)))
      padded = true;
    ++idx;
  }
  return padded;
}
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
if(is_ld_padded(x)){
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
result += 4 * result / ld;
}
if(buffer_info_->is_double(x))
result *= 2;
return result;
@@ -23,8 +39,9 @@ void shmem_allocation::run(){
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<ir::value *> I;
for(auto x: liveness_->intervals())
for(auto x: liveness_->intervals()){
I.push_back(x.first);
}
std::vector<ir::value *> J = I;
triples_map_type H;