[codegen] added leading dimension padding for transposition in shared memory
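Note on the change (our reading of the commit title; the hunks below only show the mechanics): when a tile is staged in shared memory and read back transposed, the strided walk advances by the leading dimension, and if that stride is a multiple of the bank count every access hits the same bank. Padding the leading dimension (by 4 elements here, see shmem_allocation::is_ld_padded below) spreads consecutive rows across banks. A minimal standalone sketch of the bank arithmetic, assuming the usual 32 four-byte banks:

#include <cstdio>

// Bank hit by element (row, col) of a tile whose leading dimension is ld.
// 32 banks of 4-byte words is an assumption (typical NVIDIA hardware).
int bank(int row, int col, int ld) { return (row * ld + col) % 32; }

int main() {
  // Walking down one column of a 32-wide fp32 tile:
  for (int row = 0; row < 4; row++)
    std::printf("ld=32 -> bank %d   ld=36 -> bank %d\n",
                bank(row, 0, 32),   // unpadded: every row lands in bank 0
                bank(row, 0, 36));  // padded by 4: banks 0, 4, 8, 12
}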
@@ -14,6 +14,17 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
  }
}

template<class T>
void simple_gemm(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
  if(AT && BT)
    simple_gemm<T, true, true>(c, a, b, M, N, K);
  else if(AT && !BT)
    simple_gemm<T, true, false>(c, a, b, M, N, K);
  else if(!AT && BT)
    simple_gemm<T, false, true>(c, a, b, M, N, K);
  else
    simple_gemm<T, false, false>(c, a, b, M, N, K);
}

class timer{
  typedef std::chrono::high_resolution_clock high_resolution_clock;
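The compile-time overload this dispatcher forwards to is defined above the hunk and not shown here; a minimal sketch of what such a reference GEMM looks like, with row-major storage as our assumption:

#include <cstddef>
#include <vector>

// Hypothetical reference kernel matching the dispatcher above. AT/BT say
// whether A (M-by-K) and B (K-by-N) are stored transposed; row-major
// storage is our assumption, not something this diff shows.
template<class T, bool AT, bool BT>
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                 size_t M, size_t N, size_t K) {
  for(size_t m = 0; m < M; m++)
  for(size_t n = 0; n < N; n++){
    T acc = 0;
    for(size_t k = 0; k < K; k++)
      acc += (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
    c[m*N + n] = acc;
  }
}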
@@ -5,63 +5,104 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"

const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
            int32 M, int32 N, int32 K,
            int32 lda, int32 ldb, int32 ldc,
            int32 *locks, int32 grid0, int32 grid1) {
  int32 rxa[TM] = get_global_range[TM](0);
  int32 ryb[TN] = get_global_range[TN](1);
  int32 rz = get_global_range[1](2);
  int32 rka[TK] = 0 ... TK;
  int32 rkb[TK] = 0 ... TK;
  fp32 c[TM, TN] = 0;
  int32 div = K / GZ;
  int32 rem = K % GZ;
  K = select(rz < rem, div - 1, div);
  int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
  fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
  fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
  fp32 a[TM, TK] = *pa;
  fp32 b[TN, TK] = *pb;
  int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
  int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
  last_a = last_a / TK * TK;
  last_b = last_b / TK * TK;
  int32 bound = K - max(last_a, last_b);
  for(int32 k = K; k > bound; k = k - TK){
    c = dot(a, trans(b), c);
    pa = pa + TK*lda;
    pb = pb + TK*ldb;
    a = *pa;
    b = *pb;
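The loads of a and b at the bottom of this loop are a one-iteration lookahead: dot always consumes tiles fetched on the previous trip, so the loads can overlap the math. The same pattern in scalar form (our illustration, not code from the repo; assumes a non-empty input):

#include <cstddef>
#include <vector>

// One-iteration lookahead: consume the element loaded on the previous
// trip, advance, then issue the next load so it could overlap compute.
float sum_pipelined(const std::vector<float>& x) {
  const float* p = x.data();
  float v = *p;                      // preload before entering the loop
  float acc = 0;
  for (size_t i = 0; i + 1 < x.size(); i++) {
    acc += v;                        // use the preloaded element
    p += 1;                          // advance the pointer
    v = *p;                          // prefetch for the next iteration
  }
  return acc + v;                    // fold in the final element
}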
std::string triton_source(bool AT, bool BT) {
  std::string AS0 = "TM", AS1 = "TK";
  std::string BS0 = "TK", BS1 = "TN";
  std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
  std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
  std::string lda0 = "*lda", lda1 = "";
  std::string ldb0 = "", ldb1 = "*ldb";
  std::string usea = AT ? "trans(a)" : "a";
  std::string useb = BT ? "trans(b)" : "b";
  if(AT){
    std::swap(AS0, AS1);
    std::swap(bca0, bca1);
    std::swap(lda0, lda1);
  }
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
  int1 checka[TM, 1] = rxc[:, newaxis] < M;
  int1 checkb[TN, 1] = ryc[:, newaxis] < N;
  fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
  fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
  fp32 a[TM, 1] = checka ? *pa : 0;
  fp32 b[TN, 1] = checkb ? *pb : 0;
  c = dot(a, trans(b), c);
  if(BT){
    std::swap(BS0, BS1);
    std::swap(bcb0, bcb1);
    std::swap(ldb0, ldb1);
  }
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
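To see what these fragments buy, here is the pa set-up line of the template string below, expanded by hand for both values of AT (our assembly, not output captured from the program):

// AT == false (defaults):
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
// AT == true (AS0/AS1, bca0/bca1 and lda0/lda1 swapped):
fp32* pa[TK, TM] = A + (offk + rka[:, newaxis]) + rxa[newaxis, :]*lda;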
  std::string res =
  R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
            int32 M, int32 N, int32 K,
            int32 lda, int32 ldb, int32 ldc,
            int32 *locks, int32 grid0, int32 grid1) {
  int32 rxa[TM] = get_global_range[TM](0);
  int32 ryb[TN] = get_global_range[TN](1);
  int32 rz = get_global_range[1](2);
  int32 rka[TK] = 0 ... TK;
  int32 rkb[TK] = 0 ... TK;
  fp32 c[TM, TN] = 0;
  int32 div = K / GZ;
  int32 rem = K % GZ;
  K = select(rz < rem, div - 1, div);
  int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
  fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
  fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
  fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa;
  fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb;
  int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
  int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
  last_a = last_a / TK * TK;
  last_b = last_b / TK * TK;
  int32 bound = K - max(last_a, last_b);
  for(int32 k = K; k > bound; k = k - TK){
    c = dot()" + usea + ", " + useb + R"(, c);
    pa = pa + TK)" + lda0 + R"(;
    pb = pb + TK)" + ldb0 + R"(;
    a = *pa;
    b = *pb;
  }
  int32 rxc[TM] = get_global_range[TM](0);
  int32 ryc[TN] = get_global_range[TN](1);
  for(int32 k = bound; k > 0; k = k - 1){
    int1 checka[TM, 1] = rxc[:, newaxis] < M;
    int1 checkb[TN, 1] = ryc[:, newaxis] < N;
    fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
    fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
    fp32 a[TM, 1] = checka ? *pa : 0;
    fp32 b[TN, 1] = checkb ? *pb : 0;
    c = dot(a, trans(b), c);
  }
  int32 ridx = get_range_id(0);
  int32 ridy = get_range_id(1);
  fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
  int32 *plock = locks + ridx + ridy*grid0;
  while(__atomic_cas(plock, 0, 1));
  int32 *pcount = plock + grid0*grid1;
  int32 count = *pcount;
  int32 countp1 = select(count == GZ - 1, 0, count + 1);
  int1 checkc0[TM] = rxc < M;
  int1 checkc1[TN] = ryc < N;
  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
  if(count == 0) {
    @checkc *pc = c;
    *pcount = countp1;
  }
  else {
    @checkc *pc = c + *pc;
    *pcount = countp1;
  }
  __atomic_cas(plock, 1, 0);
}
)";
  return res;
}
)";

int main() {
  bool AT = false;
  bool BT = true;

  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::jit jit(context);

@@ -128,16 +169,16 @@ int main() {

  // just-in-time compile source-code
  std::vector<unsigned> params = {
    16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
  };
  // jit.autotune("matmul",src, benchmark);
  jit.add_module("matmul", src, params);
  std::string src = triton_source(AT, BT);
  // jit.autotune("matmul",src.c_str(), benchmark);
  jit.add_module("matmul", src.c_str(), {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
  // jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1});
  // jit.add_module("matmul", src.c_str(), {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1});
  triton::driver::kernel* kernel = jit.get_function("matmul");
  triton::jit::launch_information info = jit.get_launch_info("matmul");
  std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
  stream->read(dc, true, 0, hc);
  simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
  simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
  for(size_t i = 0; i < M*N; i++)
    if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
@@ -88,6 +88,85 @@ void conv(read_only restrict fp32 *a,
  @checkc *pc = C;
})";

void build_conv_lut(int TK,
                    int stride_d, int stride_h, int stride_w, int stride_c,
                    int pad_d, int pad_h, int pad_w,
                    int T, int R, int S,
                    std::vector<int>& res, std::vector<int>& masks) {
  /* convolution parameters */
  int F = T * R * S;
  int Nlut = (TK + F - 1) / F * F;
  int upsample_w = 1;
  int upsample_h = 1;
  int upsample_d = 1;
  /* unpack index wrt filters */
  auto unpack = [&](int32_t trs){
    int32_t tr = trs / S;
    int32_t s = trs - tr*S;
    int32_t t = tr / R;
    int32_t r = tr - t*R;
    return std::make_tuple(t, r, s);
  };
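  // Worked example (ours): with T = 1, R = S = 3, the flat filter index
  // trs = 5 unpacks as tr = 5/3 = 1, s = 5 - 3 = 2, t = 1/3 = 0, r = 1,
  // i.e. filter tap (t, r, s) = (0, 1, 2).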
  /* increments */
  for(size_t i = 0; i < Nlut; ++i)
    res[i] = (((i + TK) % Nlut) - i);
  /* deltas */
  size_t Ds0 = Nlut;
  size_t Ds1 = upsample_w;
  size_t Ds2 = upsample_h;
  size_t Ds3 = upsample_d;
  for(size_t pd = 0; pd < Ds3; ++pd)
  for(size_t ph = 0; ph < Ds2; ++ph)
  for(size_t pw = 0; pw < Ds1; ++pw){
    int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
    // cumulative increments
    for(size_t i = 0; i < Ds0; ++i){
      int32_t ctrs = i;
      int32_t c = ctrs / F;
      int32_t t, r, s;
      std::tie(t, r, s) = unpack(ctrs % F);
      // next indices
      int32_t nextctrs = ctrs + TK;
      int32_t nextc = nextctrs / F;
      int32_t nextt, nextr, nexts;
      std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
      // diffs
      int32_t cdiff = nextc - c;
      int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
      int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
      int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
      // delta pointers
      deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
    }
  }

  /* Masks */
  size_t Ms0 = Nlut;
  size_t Ms1 = 2*pad_w + 1;
  size_t Ms2 = 2*pad_h + 1;
  size_t Ms3 = 2*pad_d + 1;

  for(size_t pd = 0; pd < Ms3; ++pd)
  for(size_t ph = 0; ph < Ms2; ++ph)
  for(size_t pw = 0; pw < Ms1; ++pw){
    int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
    for(size_t i = 0; i < Ms0; ++i){
      int32_t t, r, s;
      int32_t mask = 0x0;
      for(size_t j = 0; j < TK; ++j){
        std::tie(t, r, s) = unpack((i + j) % F);
        bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
        bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
        bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
        mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
      }
      masks_ptr[i] = mask;
    }
  }
  for(size_t i = 0; i < Nlut; ++i)
    masks[i] = 0x0;
}
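Each mask entry packs one bounds-check bit per element of a TK-wide slice: bit j of masks_ptr[i] says whether filter tap (i + j) % F stays inside the unpadded image at the given padding offset, so the kernel can predicate loads on a single bit test. A hand-computed case (ours):

// T = 1, R = S = 3 (F = 9), pad_h = pad_w = 1, TK = 8, entry i = 0 at
// offset (ph, pw) = (0, 0), i.e. the top-left corner of the image.
// Bit j covers tap (r, s) = (j / 3, j % 3); taps with r == 0 or s == 0
// fall into the padding there, so of the taps covered by j < 8 only
// (1,1), (1,2) and (2,1) are in bounds:
//   j    : 0 1 2 3 4 5 6 7
//   bit  : 0 0 0 0 1 1 0 1      -> mask = 0b10110000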

torch::Tensor conv_forward(
    const torch::Tensor data,
    const torch::Tensor weight) {
@@ -95,37 +174,118 @@ torch::Tensor conv_forward(
  CHECK_INPUT(data);
  CHECK_INPUT(weight);
  // Unpack data shapes
  const auto B = data.size(0);
  const auto Ci = data.size(1);
  const auto H = data.size(2);
  const auto W = data.size(3);
  const int32_t B = data.size(0);
  const int32_t Ci = data.size(1);
  const int32_t H = data.size(2);
  const int32_t W = data.size(3);
  // Unpack weight shapes
  const auto Cf = weight.size(0);
  const auto R = weight.size(1);
  const auto S = weight.size(2);
  const auto K = weight.size(3);
  const int32_t Cf = weight.size(0);
  const int32_t T = 1;
  const int32_t R = weight.size(1);
  const int32_t S = weight.size(2);
  const int32_t NF = weight.size(3);
  // Conv parameters
  int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
  int32_t pad_d = 0, pad_h = 0, pad_w = 0;
  int32_t stride_h = 1, stride_w = 1;
  // Output shapes
  int32_t P = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
  int32_t Q = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
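  // Worked case (ours): H = W = 32 with R = S = 3, unit stride/upsample and
  // zero padding (the defaults above) give P = Q = 30, a "valid" convolution.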
  // Allocate output
  AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
  torch::Tensor output = torch::empty({B, K, H, W}, torch::kFloat);
  torch::Tensor output = torch::empty({B, NF, P, Q}, torch::kFloat).cuda();
  // Wrap CUDA handles
  triton::driver::cu_stream sstream(at::cuda::getCurrentCUDAStream(), false);
  c10::DeviceIndex device = output.storage().device().index();
  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
  triton::driver::stream* stream = &sstream;
  triton::driver::context* ctx = stream->context();
  triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
  triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
  triton::driver::cu_buffer a(ctx, (CUdeviceptr)output.storage().data(), false);
  // Create JIT
  triton::jit jit(ctx);
  std::vector<unsigned> params = {
    16, 2, 64,
    32, 2, 64,
    16, 8, 2, 2,
    8, 8,
    8, 1, 8,
    4
  };
  jit.add_module("conv", src, params);
  triton::driver::kernel* kernel = jit.get_function("conv");
  triton::jit::launch_information info = jit.get_launch_info("conv");

  // launch info
  unsigned TM = info.global_range_size[0];
  unsigned TN = info.global_range_size[1];
  unsigned TK = jit.get_int("TK");
  // initialize constant memory
  int FS = T*R*S;
  int nlut = (TK + FS - 1) / FS * FS;
  std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
  std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
  // memory stride for images
  int32_t stride_i_w = 1;
  int32_t stride_i_h = W*stride_i_w;
  int32_t stride_i_d = H*stride_i_h;
  int32_t stride_i_c = 1*stride_i_d;
  int32_t stride_i_n = Ci*stride_i_c;
  // memory stride for activations
  int32_t stride_o_q = 1;
  int32_t stride_o_p = Q*stride_o_q;
  int32_t stride_o_m = P*stride_o_p;
  int32_t stride_o_k = 1*stride_o_m;
  int32_t stride_o_n = NF*stride_o_k;
  build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
  // equivalent matmul dimensions
  int32_t M = B*P*Q;
  int32_t N = NF;
  int32_t K = Ci*R*S;
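  // (Our note) this is the implicit-GEMM view of the convolution: every
  // output pixel (b, p, q) is one of M = B*P*Q rows, every filter one of
  // N = NF columns, and the K = Ci*R*S reduction walks channels and filter
  // taps through the delta/mask tables built above.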
  triton::driver::buffer* delta = jit.get_buffer("delta");
  triton::driver::buffer* masks = jit.get_buffer("masks");
  stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
  stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
  // launch info
  unsigned nthreads = info.num_threads;
  std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
  // fast bounds-checking
  unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
  unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
  unsigned lastk = TK - 1;
  bool AT = false;
  bool BT = true;
  unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
  unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
  int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
  // set arguments
  kernel->setArg(0, *d.cu());
  kernel->setArg(1, *w.cu());
  kernel->setArg(2, *a.cu());
  kernel->setArg(3, M);
  kernel->setArg(4, N);
  kernel->setArg(5, K);
  kernel->setArg(6, B);
  kernel->setArg(7, H);
  kernel->setArg(8, W);
  kernel->setArg(9, B);
  kernel->setArg(10, NF);
  kernel->setArg(11, P);
  kernel->setArg(12, Q);
  kernel->setArg(13, Ci);
  kernel->setArg(14, R);
  kernel->setArg(15, S);
  kernel->setArg(16, stride_i_n);
  kernel->setArg(17, stride_i_c);
  kernel->setArg(18, stride_i_h);
  kernel->setArg(19, stride_i_w);
  kernel->setArg(20, stride_o_n);
  kernel->setArg(21, stride_o_k);
  kernel->setArg(22, stride_o_p);
  kernel->setArg(23, stride_o_q);
  kernel->setArg(24, pad_h);
  kernel->setArg(25, pad_w);
  kernel->setArg(26, bound);
  // dry run
  stream->enqueue(kernel, grid, {nthreads, 1, 1});
  return output;
}

@@ -26,6 +26,7 @@ public:

  // utilities
  unsigned get_num_bytes(ir::value *x);
  bool is_ld_padded(ir::value* x);

  // accessors
  unsigned get_offset(ir::value *x) const { return offsets_.at(x); }

@@ -58,7 +58,6 @@ public:
    target_(target) { }

  void target_independent(ir::module &module) {
    // ir::print(module, std::cout);
    optimize_dot.run(module);
    optimize_trans.run(module);
  }

@@ -525,10 +525,12 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
  for(ir::value *op: user->ops())
    create_tile(op, builder, references, seen, sh_mem_ptr);
  LLVMContext &ctx = builder.getContext();
  const auto& shapes = v->get_type()->get_tile_shapes();
  std::vector<unsigned> shapes2;
  for(ir::constant_int* shape: shapes)
    shapes2.push_back(shape->get_value());
  const auto& cshapes = v->get_type()->get_tile_shapes();
  std::vector<unsigned> shapes;
  for(ir::constant_int* shape: cshapes)
    shapes.push_back(shape->get_value());
  if(alloc_->is_ld_padded(v))
    shapes[0] += 4;
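    // (Our note) this is the core of the commit: tiles that are read back
    // transposed get their leading dimension padded by 4 elements,
    // presumably so the strided shared-memory accesses spread across banks
    // rather than conflict.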
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
  // create shared tile
  if(buffer_info_->is_shared(v)){
@@ -550,13 +552,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
    Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
    pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
    Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
    tmap_.insert({phi, new shared_tile(ty, shapes2, ptr, builder, offset)});
    tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
    for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
      ir::basic_block* inc_block = phi->get_incoming_block(i);
      ir::value* inc_value = phi->get_incoming_value(i);
      ir::instruction* terminator = inc_block->get_inst_list().back();
      bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
      tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
      tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
    }
  }
  else {
@@ -564,16 +566,16 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
    size_t offset = alloc_->get_offset(v);
    Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
    ptr = builder.CreateBitCast(ptr, ptr_ty);
    tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
    tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
  }
}
// create distributed tile
else {
  const auto &shapes = v->get_type()->get_tile_shapes();
  std::vector<distributed_axis> axes(shapes.size());
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d]->get_value() > 1){
  const auto &cshapes = v->get_type()->get_tile_shapes();
  std::vector<distributed_axis> axes(cshapes.size());
  for(size_t d = 0; d < cshapes.size(); d++){
    if(cshapes[d]->get_value() > 1){
      ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
      axes[d] = axes_.at(x);
    }
@@ -583,7 +585,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
    }
  }
  bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
  distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
  distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
  tmap_.insert({v, T});
  // constant range
  if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){

@@ -10,8 +10,24 @@
namespace triton{
namespace codegen{

bool shmem_allocation::is_ld_padded(ir::value *x) {
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    bool result = false;
    for(unsigned i = 0; i < phi->get_num_incoming(); i++)
      result = result | is_ld_padded(phi->get_incoming_value(i));
    return result;
  }
  if(dynamic_cast<ir::trans_inst*>(x))
    return true;
  return false;
}

unsigned shmem_allocation::get_num_bytes(ir::value *x) {
  unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
  if(is_ld_padded(x)){
    unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
    result += 4 * result / ld;
  }
  if(buffer_info_->is_double(x))
    result *= 2;
  return result;
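Sizing example (our arithmetic): for a 32x32 fp32 tile, result starts at 32*32*4 = 4096 bytes; the padding term adds 4*4096/32 = 512 bytes, i.e. (32+4)*32*4 = 4608 in total, and double-buffered values then take twice that.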

@@ -23,8 +39,9 @@ void shmem_allocation::run(){
  typedef std::multimap<unsigned, segment> triples_map_type;

  std::vector<ir::value *> I;
  for(auto x: liveness_->intervals())
  for(auto x: liveness_->intervals()){
    I.push_back(x.first);
  }
  std::vector<ir::value *> J = I;

  triples_map_type H;