[general] major overhaul of triton-c/triton-ir/triton-jit:
- Added alloc_const
- Added atomics
- Pruned tuning space
- Added examples for dot/conv/shift
- Bugfixes
@@ -1,6 +1 @@
-foreach(PROG matrix)
-add_executable(${PROG} ${PROG}.cpp)
-set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
-include_directories(/usr/local/cuda/include/)
-target_link_libraries(${PROG} triton)
-endforeach(PROG)
+add_subdirectory(cpp)
examples/cpp/CMakeLists.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
foreach(PROG dot conv shift)
add_executable(${PROG} ${PROG}.cpp)
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
include_directories(/usr/local/cuda/include/)
target_link_libraries(${PROG} triton)
endforeach(PROG)
@@ -1,17 +1,18 @@
 #include <cstring>
 #include <cstdio>
+#include "common.hpp"
 #include "triton/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"

 const char* src =
 R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
+const tunable int32 TM = {16, 32, 64, 128};
+const tunable int32 TN = {8};
 const tunable int32 TK = {8};

-void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
-            int32 M, int32 N, int32 K, int32 bound){
+void blocksparse(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
+            int32 M, int32 N, int32 K, int32 bound){
 int32 rxa[TM] = get_global_range[TM](0);
 int32 ryb[TN] = get_global_range[TN](1);
 int32 rka[TK] = 0 ... TK;
@@ -22,9 +23,9 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
 fp32 a[TM, TK] = *pa;
 fp32 b[TN, TK] = *pb;
 for(int32 k = K; k > 0;){
-C = dot(a, b, C);
+C = dot(a, trans(b), C);
 pa = pa + TK*M;
-pb = pb + TK*K;
+pb = pb + TK*N;
 k = k - TK;
 int1 checka[TM, TK] = k > bound;
 int1 checkb[TN, TK] = k > bound;
@@ -51,71 +52,24 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
 }
 )";

-template<class T>
-void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
-for(size_t m = 0; m < M; m++)
-for(size_t n = 0; n < N; n++){
-T acc = 0;
-for(size_t k = 0; k < K; k++)
-acc += a[m + k*M] * b[n + k*N];
-c[m + n*M] = acc;
+std::vector<int> make_deltas(std::vector<int> mask, int K, int N){
+std::vector<std::vector<std::pair<int,int>>> pairs(N);
+unsigned int current = 0;
+for(int k = 0; k < K; k++)
+for(int n = 0; n < N; n++){
+if(mask[k + n*K])
+pairs[n].push_back({current, k});
 }
 }

-class timer{
-typedef std::chrono::high_resolution_clock high_resolution_clock;
-typedef std::chrono::nanoseconds nanoseconds;
-
-public:
-explicit timer(bool run = false)
-{ if (run) start(); }
-
-void start()
-{ _start = high_resolution_clock::now(); }
-
-nanoseconds get() const
-{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
-
-private:
-high_resolution_clock::time_point _start;
-};
-
-template<class T>
-T min(std::vector<T> x)
-{ return *std::min_element(x.begin(), x.end()); }
-
-template<class OP, class SYNC>
-double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
-{
-timer tmr;
-std::vector<size_t> times;
-double total_time = 0;
-op();
-sync();
-while(total_time*1e-9 < 1e-3){
-float norm = 1;
-// normalize clock if possible to get roughly constant result
-if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(&device))
-norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
-tmr.start();
-op();
-sync();
-times.push_back(norm*tmr.get().count());
-total_time+=times.back();
-}
-return min(times);
-}

 int main() {
 // initialize default compute device
 auto context = triton::driver::backend::contexts::get_default();
 triton::jit jit(context);

 // matrix multiplication parameters
-int32_t M = 512, N = 512, K = 512;
+int32_t M = 512, N = 32, K = 2048;
 std::vector<float> hc(M*N);
 std::vector<float> rc(M*N);
 std::vector<float> ha(M*K);
@@ -183,14 +137,13 @@ int main() {
 8, 8,
 4
 };

-jit.autotune(src, benchmark);
-jit.add_module(src, params);
+jit.autotune("matmul",src, benchmark);
+jit.add_module("matmul", src, params);
 triton::driver::kernel* kernel = jit.get_function("matmul");
 triton::jit::launch_information info = jit.get_launch_info("matmul");
-std::cout << benchmark(kernel, info) << std::endl;
+std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
 stream->read(dc, true, 0, hc);
-simple_gemm(rc, ha, hb, M, N, K);
+simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
 for(size_t i = 0; i < M*N; i++)
 if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
 std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
examples/cpp/common.hpp (new file, 286 lines)
@@ -0,0 +1,286 @@
#include <vector>
#include <chrono>
#include "triton/driver/device.h"
#include <algorithm>

template<class T, bool AT, bool BT>
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
for(size_t k = 0; k < K; k++)
acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
c[m + n*M] = acc;
}
}
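// Usage sketch (not part of the original file): the AT/BT flags pick the
// traversal order of the A and B buffers. With AT=false, BT=true this
// reference matches the kernels in this commit, which load b as a [TN, TK]
// tile and accumulate dot(a, trans(b), C). Buffers are column-major here.
//
//   std::vector<float> a = {1, 2, 3, 4};   // A:  M x K = 2 x 2, A = [[1,3],[2,4]]
//   std::vector<float> b = {5, 6, 7, 8};   // B': N x K = 2 x 2, B'[n,k] = b[n + k*N]
//   std::vector<float> c(4);
//   simple_gemm<float, false, true>(c, a, b, 2, 2, 2);
//   // c[0] = C[0,0] = A[0,0]*B'[0,0] + A[0,1]*B'[0,1] = 1*5 + 3*7 = 26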


class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;

public:
explicit timer(bool run = false)
{ if (run) start(); }

void start()
{ _start = high_resolution_clock::now(); }

nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }

private:
high_resolution_clock::time_point _start;
};

template<class T>
T min(std::vector<T> x)
{ return *std::min_element(x.begin(), x.end()); }


template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(&device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
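// Usage sketch (not part of the original file): `op` launches the kernel and
// `sync` blocks until completion; one warm-up call runs before timing and the
// result is the minimum clock-normalized time in nanoseconds. This mirrors
// the benchmark lambdas in dot.cpp/conv.cpp/shift.cpp below:
//
//   double ns = bench([&](){ stream->enqueue(kernel, grid, {nthreads, 1, 1}); },
//                     [&](){ stream->synchronize(); },
//                     *context->device());
//   double tflops = 2.*M*N*K / (ns * 1e-9) * 1e-12;   // GEMM FLOP count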

//

void build_conv_lut(int TK,
int stride_d, int stride_h, int stride_w, int stride_c,
int pad_d, int pad_h, int pad_w,
int T, int R, int S,
std::vector<int>& res, std::vector<int>& masks) {
/* convolution parameters */
int F = T * R * S;
int Nlut = (TK + F - 1) / F * F;
int upsample_w = 1;
int upsample_h = 1;
int upsample_d = 1;
/* unpack index wrt filters */
auto unpack = [&](int32_t trs){
int32_t tr = trs / S;
int32_t s = trs - tr*S;
int32_t t = tr / R;
int32_t r = tr - t*R;
return std::make_tuple(t, r, s);
};
/* increments */
for(size_t i = 0; i < Nlut; ++i)
res[i] = (((i + TK) % Nlut) - i);
/* deltas */
size_t Ds0 = Nlut;
size_t Ds1 = upsample_w;
size_t Ds2 = upsample_h;
size_t Ds3 = upsample_d;
for(size_t pd = 0; pd < Ds3; ++pd)
for(size_t ph = 0; ph < Ds2; ++ph)
for(size_t pw = 0; pw < Ds1; ++pw){
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
// cumulative increments
for(size_t i = 0; i < Ds0; ++i){
int32_t ctrs = i;
int32_t c = ctrs / F;
int32_t t, r, s;
std::tie(t, r, s) = unpack(ctrs % F);
// next indices
int32_t nextctrs = ctrs + TK;
int32_t nextc = nextctrs / F;
int32_t nextt, nextr, nexts;
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
// diffs
int32_t cdiff = nextc - c;
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
// delta pointers
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
}
}

/* Masks */
size_t Ms0 = Nlut;
size_t Ms1 = 2*pad_w + 1;
size_t Ms2 = 2*pad_h + 1;
size_t Ms3 = 2*pad_d + 1;

for(size_t pd = 0; pd < Ms3; ++pd)
for(size_t ph = 0; ph < Ms2; ++ph)
for(size_t pw = 0; pw < Ms1; ++pw){
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
for(size_t i = 0; i < Ms0; ++i){
int32_t t, r, s;
int32_t mask = 0x0;
for(size_t j = 0; j < TK; ++j){
std::tie(t, r, s) = unpack((i + j) % F);
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
}
masks_ptr[i] = mask;
}
}
for(size_t i = 0; i < Nlut; ++i)
masks[i] = 0x0;
}
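// Sizing sketch (not part of the original file): for the 1x3x3 filter used
// by conv.cpp below with TK = 8, F = 1*3*3 = 9 and Nlut = ceil(8/9)*9 = 9,
// so `res` holds 9 increments plus 9 deltas (hence `alloc_const int32[18]`
// in the conv kernel), and `masks` holds 9 zero words plus one 9-word
// bitmask table per (pad_d, pad_h, pad_w) offset; stride_* are the input
// strides computed in conv.cpp's main():
//
//   int TK = 8, T = 1, R = 3, S = 3;
//   int F = T*R*S, nlut = (TK + F - 1)/F*F;
//   std::vector<int> h_delta(nlut + 1*1*1*nlut);                   // upsample = 1
//   std::vector<int> h_masks(nlut + (2*1+1)*(2*1+1)*(2*0+1)*nlut); // pad = (0,1,1)
//   build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c,
//                  0, 1, 1, T, R, S, h_delta, h_masks);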


// Index computation
inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }


// Pack

template <class T> T clamp(T x, T lo, T hi){
return std::max<T>(lo, std::min<T>(x, hi));
}


template<class T, class U>
T pack(U* tmp, U scale);

template<>
double pack<double, double>(double* tmp, double scale)
{ return tmp[0]*scale; }

template<>
float pack<float, float>(float* tmp, float scale)
{ return tmp[0]*scale; }

template<>
int pack<int, float>(float* tmp, float scale)
{
int res = 0;
for(int i = 0; i < 4; i++){
int8_t clamped = std::round(clamp(tmp[i]*scale, (float)-128, (float)127));
res |= (clamped & 0xFF) << (8*i);
}
return res;
}

template<class T> struct pack_increment
{ enum{ VALUE = 1}; };

template<> struct pack_increment<int>
{ enum{ VALUE = 4}; };

// Dot
template<class T>
inline T dot(T x, T y, T z)
{
return std::fma(x, y, z);
}

inline int dot(int x, int y, int z){
int res = 0;
for(int i = 0; i < 4; i++){
int32_t a = ((x >> (8*i)) & 0x000000FF);
int32_t b = ((y >> (8*i)) & 0x000000FF);
res += (*(int8_t*)(&a)) * (*(int8_t*)(&b));
}
return res + z;
}
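// Worked example (not part of the original file): pack<int,float> quantizes
// four floats into one 32-bit word (one int8 per byte, little-endian), and
// dot(int, int, int) accumulates the four byte-wise products, a CPU analogue
// of a dp4a-style packed dot product:
//
//   float va[4] = {1, -2, 3, 4};
//   float vb[4] = {5, 6, -7, 8};
//   int a = pack<int>(va, 1.f);   // 0x0403FE01: bytes 1, -2, 3, 4
//   int b = pack<int>(vb, 1.f);
//   int acc = dot(a, b, 0);       // 1*5 + (-2)*6 + 3*(-7) + 4*8 = 4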



template<class IN_DTYPE, class OUT_DTYPE>
void cpp_conv_nchw(int32_t C, int32_t N, int32_t K,
int32_t D, int32_t H, int32_t W,
int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t M, int32_t P, int32_t Q,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F)
{
static const int PACK_IN = pack_increment<IN_DTYPE>::VALUE;
static const int PACK_OUT = pack_increment<OUT_DTYPE>::VALUE;
if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4");
if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4");
C /= PACK_IN;
K /= PACK_OUT;
int32_t Kout = K;
IN_DTYPE accs[PACK_OUT];
float tmp[PACK_OUT];
for(int32_t m = 0 ; m < M; ++m)
for(int32_t p = 0 ; p < P; ++p)
for(int32_t q = 0; q < Q; ++q)
for(int32_t n = 0; n < N; ++n)
for(int32_t k = 0; k < Kout ; ++k)
{
for(int32_t i = 0; i < PACK_OUT; ++i)
accs[i] = 0;
int32_t mm = m*stride_d - pad_d;
int32_t pp = p*stride_h - pad_h;
int32_t qq = q*stride_w - pad_w;
for(int32_t kk = 0; kk < PACK_OUT; ++kk)
for(int32_t c = 0; c < C; ++c)
for(int32_t t = 0; t < T; ++t)
for(int32_t r = 0; r < R; ++r)
for(int32_t s = 0; s < S; ++s){
int32_t d = mm + t;
int32_t h = pp + r;
int32_t w = qq + s;
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W);
IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0;
IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)];
accs[kk] = dot(i, f, accs[kk]);
}
for(int32_t kk = 0; kk < PACK_OUT; ++kk){
tmp[kk] = accs[kk];
}
O[idx(n, k, m, p, q, N, K, M, P, Q)] = tmp[0];
}
}


// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
int32_t K,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F,
const std::vector<int32_t> shift_h,
const std::vector<int32_t> shift_w)
{
OUT_DTYPE acc;
for(int32_t p = 0; p < H; ++p)
for(int32_t q = 0; q < W; ++q)
for(int32_t bs = 0; bs < BS; ++bs)
for(int32_t k = 0; k < K; ++k)
{
acc = 0;
for(int32_t c = 0; c < C; ++c){
int32_t h = p + shift_h[c];
int32_t w = q + shift_w[c];
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
IN_DTYPE b = F[k + c*K];
acc = dot(a, b, acc);
}
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
}
}
examples/cpp/conv.cpp (new file, 236 lines)
@@ -0,0 +1,236 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"

std::string src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};

__constant__ int32* delta = alloc_const int32[18];
__constant__ int32* masks = alloc_const int32[1024];

void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
int32 M, int32 N, int32 K,
int32 AN, int32 AH, int32 AW,
int32 CN, int32 CK, int32 CP, int32 CQ,
int32 AC, int32 AR, int32 AS,
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w,
int32 bound){
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rb1[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 ranh[TM] = rxa / CQ;
int32 raw[TM] = rxa % CQ - pad_w;
int32 ran[TM] = ranh / CP;
int32 rah[TM] = ranh % CP - pad_h;
int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
int32 racr[TK] = rka / AS;
int32 ras[TK] = rka % AS;
int32 rac[TK] = racr / AR;
int32 rar[TK] = racr % AR;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
__constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + AR*AS + rka;
int32 d[TK] = *pd;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
__constant__ int32* pincm[TM] = delta;
int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm;
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
pb = pb + TK*CK;
pa = pa + d[newaxis, :];
b = *pb;
pd = pd + incd;
pincd = pincd + incd;
d = *pd;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CP*CQ);
int32 rcpq[TM] = rxc % (CP*CQ);
int32 rc0[TM] = rcn * ldc_n + rcpq;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = rc1 < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
})";
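// Note (not part of the original file): the `checka` predicate above decodes
// the bitmasks built by build_conv_lut in common.hpp. Each word checka0[m]
// holds one bit per element of the K-tile, and lane (m, k) loads only when
// bit k is set, i.e. only when filter tap (i + k) % F stays inside the
// padded image:
//
//   bool in_bounds(int32_t mask_word, int k) {
//     return ((mask_word >> k) & 1) != 0;  // same test as (checka0 & (1 << rka)) > 0
//   }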



int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t AN = 4, CK = 32;
int32_t AD = 1, AH = 24, AW = 240;
int32_t BC = 64, BT = 1, BR = 3, BS = 3;
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
int32_t CM = (AD*upsample_d - BT + 1 + 2*pad_d + stride_d - 1)/stride_d;
int32_t CP = (AH*upsample_h - BR + 1 + 2*pad_h + stride_h - 1)/stride_h;
int32_t CQ = (AW*upsample_w - BS + 1 + 2*pad_w + stride_w - 1)/stride_w;
// equivalent matmul dimensions
int32_t M = AN*CM*CP*CQ;
int32_t N = CK;
int32_t K = BC*BT*BR*BS;
std::vector<float> hc(AN*CP*CQ*CK);
std::vector<float> rc(AN*CP*CQ*CK);
std::vector<float> ha(AN*BC*AH*AW);
std::vector<float> hb(BC*BR*BS*CK);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = 1;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = 1;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// memory strides for data
int32_t stride_i_w = 1;
int32_t stride_i_h = AW*stride_i_w;
int32_t stride_i_d = AH*stride_i_h;
int32_t stride_i_c = AD*stride_i_d;
int32_t stride_i_n = BC*stride_i_c;
// memory strides for filters
int32_t stride_f_k = 1;
int32_t stride_f_s = CK*stride_f_k;
int32_t stride_f_r = BS*stride_f_s;
int32_t stride_f_t = BR*stride_f_r;
int32_t stride_f_c = BT*stride_f_t;
// memory strides for activations
int32_t stride_o_q = 1;
int32_t stride_o_p = CQ*stride_o_q;
int32_t stride_o_m = CP*stride_o_p;
int32_t stride_o_k = CM*stride_o_m;
int32_t stride_o_n = CK*stride_o_k;
// look-up table
int TK = 8;
int F = BT * BR * BS;
int nlut = (TK + F - 1) / F * F;
std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, BT, BR, BS, h_delta, h_masks);
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned TK = jit.get_int("TK");
// initialize constant memory
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
stream->synchronize();
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set arguments
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, AN);
kernel->setArg(7, AH);
kernel->setArg(8, AW);
kernel->setArg(9, AN);
kernel->setArg(10, CK);
kernel->setArg(11, CP);
kernel->setArg(12, CQ);
kernel->setArg(13, BC);
kernel->setArg(14, BR);
kernel->setArg(15, BS);
kernel->setArg(16, stride_i_n);
kernel->setArg(17, stride_i_c);
kernel->setArg(18, stride_i_h);
kernel->setArg(19, stride_i_w);
kernel->setArg(20, stride_o_n);
kernel->setArg(21, stride_o_k);
kernel->setArg(22, stride_o_p);
kernel->setArg(23, stride_o_q);
kernel->setArg(24, pad_h);
kernel->setArg(25, pad_w);
kernel->setArg(26, bound);
// dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};
// run
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4
};
// jit.autotune("conv", src, benchmark);
jit.add_module("conv", src, params);
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
cpp_conv_nchw(BC, AN, CK, AD, AH, AW, BT, BR, BS, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, CM, CP, CQ, rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}
examples/cpp/dot.cpp (new file, 162 lines)
@@ -0,0 +1,162 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"

const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
for(int32 L = __atomic_cas(plock, 0, 1); L == 1; L = __atomic_cas(plock, 0, 1)){}
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + (checkc ? *pc : 0);
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
}
)";

int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);

// matrix multiplication parameters
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
std::vector<int32_t> hlocks(2048);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();

// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// init locks
stream->write(dlocks, true, 0, hlocks);
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, dlocks);
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
// dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};

// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}
examples/cpp/shift.cpp (new file, 212 lines)
@@ -0,0 +1,212 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"

// K = channels
// M = batch * height * width
// N = number of feature maps

const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};

__constant__ int32* delta = alloc_const int32[256];
__constant__ int32* masks = alloc_const int32[8192];

void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K,
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
__constant__ int32* pd[TK] = delta + rka;
int32 pad_h = AR/2;
int32 pad_w = AS/2;
int32 rawhc[TM] = rxa / ABS;
int32 raw[TM] = rawhc % AW - pad_w;
int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH - pad_h;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
for(int32 k = K; k > 0; k = k - TK){
int32 delta[TK] = *pd;
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
int1 m[TM, TK] = *pm > 0;
fp32 a[TM, TK] = m ? *pa : 0;
fp32 b[TN, TK] = *pb;
C = dot(a, trans(b), C);
pb = pb + TK*N;
pd = pd + TK;
pm = pm + TK;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
}
)";

std::vector<int32_t> shift_deltas(// strides
int32_t stride_w, int32_t stride_h, int32_t stride_c,
// shift
int32_t C,
const std::vector<int32_t>& shift_h,
const std::vector<int32_t>& shift_w) {
std::vector<int32_t> res(C);
for(unsigned c = 0; c < C; c++){
res[c] = c*stride_c;
res[c] += shift_h[c]*stride_h;
res[c] += shift_w[c]*stride_w;
}
return res;
}
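// Worked example (not part of the original file): with the strides used in
// main() below (stride_w = BS, stride_h = W*BS, stride_c = H*W*BS), each
// channel folds its (shift_h, shift_w) displacement into a single flattened
// pointer offset, so the kernel reaches the shifted pixel with one add:
//
//   int32_t BS = 4, W = 32, H = 32;
//   std::vector<int32_t> sh = {-1}, sw = {1};   // one channel, shift (-1, +1)
//   auto d = shift_deltas(BS, W*BS, H*W*BS, 1, sh, sw);
//   // d[0] = 0*(H*W*BS) + (-1)*(W*BS) + 1*BS = -128 + 4 = -124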

std::vector<int32_t> shift_masks(int32_t C,
const std::vector<int32_t>& shift_h,
const std::vector<int32_t>& shift_w,
int32_t R, int32_t S) {
size_t S0 = C;
size_t S1 = R;
size_t S2 = S;
std::vector<int32_t> res(S0*S1*S2);
for(size_t ph = 0; ph < S1; ++ph)
for(size_t pw = 0; pw < S2; ++pw){
int32_t* ptr = &res[ph*S0 + pw*S0*S1];
for(size_t i = 0; i < S0; ++i){
bool in_bounds_h = shift_h[i] + ph >= 0 && shift_h[i] + ph < R;
bool in_bounds_w = shift_w[i] + pw >= 0 && shift_w[i] + pw < S;
ptr[i] = in_bounds_h && in_bounds_w;
}
}
return res;
}

int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t R = 3, S = 3;
int32_t BS = 4, F = 128;
int32_t H = 32, W = 32;
int32_t C = 128;
// equivalent matmul dimensions
int32_t M = BS*H*W;
int32_t N = F;
int32_t K = C;
std::cout << M << " " << N << " " << K << std::endl;
std::vector<float> hc(BS*H*W*F);
std::vector<float> rc(BS*H*W*F);
std::vector<float> ha(BS*C*H*W);
std::vector<float> hb(F*C);
// strides
int32_t stride_i_bs = 1;
int32_t stride_i_w = BS*stride_i_bs;
int32_t stride_i_h = W*stride_i_w;
int32_t stride_i_c = H*stride_i_h;
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = rand() % R - R/2;
shift_w[c] = rand() % S - S/2;
}
// initialize buffers
srand(0);
for(int c = 0 ; c < C; c++)
for(int h = 0 ; h < H; h++)
for(int w = 0 ; w < W; w++)
for(int bs = 0 ; bs < BS; bs++){
float value = (float)rand() / RAND_MAX;
size_t idx = bs + w*stride_i_w + h*stride_i_h + c*stride_i_c;
ha[idx] = value;
}
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand() / RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
std::vector<int32_t> h_delta = shift_deltas(stride_i_w, stride_i_h, stride_i_c, C, shift_h, shift_w);
std::vector<int32_t> h_masks = shift_masks(C, shift_h, shift_w, R, S);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// initialize constant memory
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
stream->synchronize();
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, BS);
kernel->setArg(7, H);
kernel->setArg(8, W);
kernel->setArg(9, R);
kernel->setArg(10, S);
// dry run
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};

// shift
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4
};
// jit.autotune("shift", src, benchmark);
jit.add_module("shift", src, params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
shift_conv(C, H, W, BS, F, rc, ha, hb, shift_h, shift_w);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;

}
examples/cpp/shift.ptx (new file, 93 lines)
@@ -0,0 +1,93 @@
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_60
.address_size 64

// .globl _Z25shift_cuda_forward_kernelPKfPKiPfiiii

.visible .entry shift(
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6
)
{
.reg .pred %p<10>;
.reg .f32 %f<2>;
.reg .b32 %r<31>;
.reg .b64 %rd<13>;


ld.param.u64 %rd1, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0];
ld.param.u64 %rd3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1];
ld.param.u64 %rd2, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2];
ld.param.u32 %r3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3];
ld.param.u32 %r4, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4];
ld.param.u32 %r5, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5];
ld.param.u32 %r6, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6];
cvta.to.global.u64 %rd4, %rd3;
mov.u32 %r7, %ntid.x;
mov.u32 %r8, %ctaid.x;
mov.u32 %r9, %tid.x;
mad.lo.s32 %r1, %r7, %r8, %r9;
mul.lo.s32 %r10, %r4, %r3;
mul.lo.s32 %r11, %r10, %r5;
mul.lo.s32 %r12, %r11, %r6;
mul.lo.s32 %r13, %r5, %r4;
mul.lo.s32 %r14, %r13, %r6;
rem.s32 %r15, %r1, %r14;
sub.s32 %r16, %r1, %r15;
mul.lo.s32 %r17, %r6, %r5;
div.s32 %r18, %r15, %r17;
mul.lo.s32 %r19, %r18, %r17;
sub.s32 %r20, %r15, %r19;
div.s32 %r21, %r20, %r5;
mul.lo.s32 %r22, %r21, %r6;
sub.s32 %r23, %r20, %r22;
shl.b32 %r24, %r18, 1;
mul.wide.s32 %rd5, %r24, 4;
add.s64 %rd6, %rd4, %rd5;
ld.global.nc.u32 %r25, [%rd6];
add.s32 %r26, %r25, %r21;
ld.global.nc.u32 %r27, [%rd6+4];
add.s32 %r28, %r23, %r27;
add.s32 %r29, %r16, %r19;
mad.lo.s32 %r30, %r26, %r5, %r29;
add.s32 %r2, %r30, %r28;
setp.lt.s32 %p1, %r1, %r12;
setp.gt.s32 %p2, %r26, -1;
and.pred %p3, %p1, %p2;
setp.lt.s32 %p4, %r26, %r5;
and.pred %p5, %p3, %p4;
setp.gt.s32 %p6, %r28, -1;
and.pred %p7, %p5, %p6;
setp.lt.s32 %p8, %r28, %r6;
and.pred %p9, %p7, %p8;
@!%p9 bra BB0_2;
bra.uni BB0_1;

BB0_1:
cvta.to.global.u64 %rd7, %rd1;
mul.wide.s32 %rd8, %r1, 4;
add.s64 %rd9, %rd7, %rd8;
ld.global.nc.f32 %f1, [%rd9];
cvta.to.global.u64 %rd10, %rd2;
mul.wide.s32 %rd11, %r2, 4;
add.s64 %rd12, %rd10, %rd11;
st.global.f32 [%rd12], %f1;

BB0_2:
ret;
}
@@ -74,8 +74,8 @@ class constant;
 class node {
 protected:
 static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
 static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src);
 static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
 static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty);
 static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
 bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
 public:
@@ -164,6 +164,27 @@ private:
 const constant* axis_;
 };

+class get_range_id: public builtin_expression{
+public:
+get_range_id(node *axis): axis_((constant*)axis) { }
+ir::value* codegen(ir::module *) const;
+
+private:
+const constant* axis_;
+};
+
+class atomic_cas: public builtin_expression{
+public:
+atomic_cas(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
+ir::value* codegen(ir::module *) const;
+
+private:
+const node *ptr_;
+const node *cmp_;
+const node *val_;
+};

 class matmul_expression: public builtin_expression{
 public:
 matmul_expression(node* A, node *B, node *C):
@@ -176,6 +197,49 @@ private:
 const expression *C_;
 };

+class max_expression: public builtin_expression{
+public:
+max_expression(node* x, node* y)
+: x_((expression*)x), y_((expression*)y){ }
+ir::value* codegen(ir::module *) const;
+
+private:
+const expression *x_;
+const expression *y_;
+};
+
+class min_expression: public builtin_expression{
+public:
+min_expression(node* x, node* y)
+: x_((expression*)x), y_((expression*)y){ }
+ir::value* codegen(ir::module *mod) const;
+
+private:
+const expression *x_;
+const expression *y_;
+};
+
+class select_expression: public builtin_expression{
+public:
+select_expression(node* pred, node* if_value, node* else_value)
+: pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { }
+ir::value* codegen(ir::module *mod) const;
+
+private:
+const expression *pred_;
+const expression *if_value_;
+const expression *else_value_;
+};
+
+class trans_expression: public builtin_expression{
+public:
+trans_expression(node *arg): arg_(arg) {}
+ir::value* codegen(ir::module *mod) const;
+
+private:
+node* arg_;
+};

 class indexing_expression: public postfix_expression{
 public:
@@ -189,6 +253,8 @@ private:
 const list<slice*>* slices_;
 };


 class named_expression: public expression {
 public:
 named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; }
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
 %token IF ELSE FOR CONTINUE
 %token NEWAXIS ELLIPSIS AT
-%token GET_GLOBAL_RANGE DOT ALLOC_CONST
+%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ALLOC_CONST

 %start translation_unit
 %%
@@ -118,8 +118,15 @@ identifier

 builtin
 : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
+| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id($3); }
 | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
 | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); }
+| TRANS '(' expression ')' { $$ = new trans_expression($3); }
+| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
+| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
+| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
+| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas($3, $5, $7); }
 ;

 primary_expression
 : identifier { $$ = new named_expression($1); }
@@ -41,7 +41,13 @@ using triton::ast::return_void;
 "fp64" { return return_impl(FP64, yytext); }
 "..." { return return_impl(ELLIPSIS, yytext); }
 "get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); }
+"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
+"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
 "dot" { return return_impl(DOT, yytext); }
+"max" { return return_impl(MAX, yytext); }
+"min" { return return_impl(MIN, yytext); }
+"select" { return return_impl(SELECT, yytext); }
+"trans" { return return_impl(TRANS, yytext); }
 "continue" { return return_impl(CONTINUE, yytext); }
 "alloc_const" { return return_impl(ALLOC_CONST, yytext); }
 {L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); }
@@ -52,8 +58,6 @@ using triton::ast::return_void;
 L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); }

 {D}+{E}{FS}? { return return_impl(CONSTANT, yytext); }
 {D}*"."{D}+({E})?{FS}? { return return_impl(CONSTANT, yytext); }
 {D}+"."{D}*({E})?{FS}? { return return_impl(CONSTANT, yytext); }

 L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); }
@@ -1,45 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LAYOUT_H
#define TDL_INCLUDE_IR_CODEGEN_LAYOUT_H

#include <vector>
#include <map>

namespace triton {

namespace ir {
class module;
class instruction;
class value;
}

namespace codegen{

struct shared_view_info{
ir::value *usr;
bool has_dedicated_storage;
};

class layout {
private:
typedef std::vector<shared_view_info> shared_view_val_t;

void add_phi_nodes(ir::value *v);
void add_shared_views(ir::value *v);

public:
// accessors
unsigned get_num_shared_views(ir::value *v);
shared_view_info get_shared_view(ir::value *v, unsigned idx);

// run
void run(ir::module &mod);

private:
std::map<ir::value*, shared_view_val_t> shared_views_;
};

}
}

#endif
include/triton/codegen/optimize_cse.h (new file, 27 lines)
@@ -0,0 +1,27 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H

#include <tuple>
#include <vector>
#include <set>

namespace triton {

namespace ir {
class module;
}

namespace codegen{
class tune;

class optimize_cse {
public:
optimize_cse() {}
void run(ir::module &mod);
};

}
}

#endif
include/triton/codegen/optimize_dot.h (new file, 31 lines)
@@ -0,0 +1,31 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H

#include <tuple>
#include <vector>
#include <set>

namespace triton {

namespace ir {
class module;
}

namespace codegen{

class tune;

class optimize_dot {
public:
optimize_dot(tune* params): params_(params) {}
void run(ir::module &mod);

private:
tune* params_;
};

}
}

#endif
include/triton/codegen/optimize_trans.h (new file, 33 lines)
@@ -0,0 +1,33 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H

#include <tuple>
#include <vector>
#include <set>

namespace triton {

namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
}

namespace codegen{

class optimize_trans {
private:
ir::value *replace_phi(ir::value* value, std::vector<ir::instruction*>& to_delete, ir::builder &builder);

public:
optimize_trans() {}
void run(ir::module &mod);
};

}
}

#endif
@@ -7,7 +7,7 @@
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/ir/type.h"
-#include "triton/codegen/buffer_info.h"
+#include "triton/codegen/shmem_info.h"

 namespace llvm{
@@ -21,9 +21,9 @@ namespace llvm{
 namespace triton{
 namespace codegen{

-class allocation;
+class shmem_allocation;
 class tune;
-class buffer_info_pass;
+class shmem_info;
 class target;

 typedef std::vector<llvm::Value*> indices_t;
@@ -129,7 +129,7 @@ private:
 void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);

 public:
-selection(allocation *alloc, tune *params, buffer_info_pass *buffer_info, target *tgt)
+selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, target *tgt)
 : alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ }

 void run(ir::module &src, llvm::Module &dst);
@@ -139,11 +139,12 @@ private:
 tmap_t tmap_;
 pmap_t pmap_;
 pmap_t last_block_;
-allocation *alloc_;
+shmem_allocation *alloc_;
 tune *params_;
 target *tgt_;
-buffer_info_pass *buffer_info_;
+shmem_info *buffer_info_;
 std::map<ir::metaparameter*, distributed_axis> axes_;
 llvm::Value *sh_mem_ptr_;
 };

 }
@@ -1,41 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_SHARED_COPY_H
#define TDL_INCLUDE_CODEGEN_SHARED_COPY_H

#include <tuple>
#include <vector>

namespace triton {

namespace ir {
class module;
class value;
class builder;
class basic_block;
}

namespace codegen{

class buffer_info_pass;

class place_shared_copy {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;

private:
bool intersect(const interval_vec_t &I, interval_t i);
void add_copy(ir::value *x, ir::builder &builder);

public:
place_shared_copy(buffer_info_pass *info): info_(info) { }
void run(ir::module &mod);

private:
buffer_info_pass *info_;
};

}
}

#endif
@@ -16,12 +16,12 @@ namespace codegen{

 class layout;
 class target_tuner;
-class liveness;
-class buffer_info_pass;
+class shmem_liveness;
+class shmem_info;

-class allocation {
+class shmem_allocation {
 public:
-allocation(liveness *live, buffer_info_pass *buffer_info)
+shmem_allocation(shmem_liveness *live, shmem_info *buffer_info)
 : liveness_(live), buffer_info_(buffer_info){ }

 // utilities
@@ -39,8 +39,8 @@ private:
 std::map<ir::value*, unsigned> num_bytes_;
 size_t allocated_size_;
 // dependences
-liveness *liveness_;
-buffer_info_pass *buffer_info_;
+shmem_liveness *liveness_;
+shmem_info *buffer_info_;
 };

 }
@@ -17,10 +17,10 @@ namespace ir {

 namespace codegen{

-class allocation;
-class buffer_info_pass;
+class shmem_allocation;
+class shmem_info;

-class barriers {
+class shmem_barriers {
 private:
 typedef std::pair<unsigned, unsigned> interval_t;
 typedef std::vector<interval_t> interval_vec_t;
@@ -36,12 +36,12 @@ private:
 std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);

 public:
-barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
+shmem_barriers(shmem_allocation *alloc, shmem_info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
 void run(ir::module &mod);

 private:
-allocation *alloc_;
-buffer_info_pass *buffer_info_;
+shmem_allocation *alloc_;
+shmem_info *buffer_info_;
 };

@@ -10,18 +10,19 @@ namespace ir {
 class module;
 class value;
 class phi_node;
+class instruction;
 }

 namespace codegen{

-class buffer_info_pass {
+class shmem_info {
 public:
 void run(ir::module &mod);
 // queries
 bool is_double(ir::value *x);
 void add_shared(ir::value *v);
 bool is_shared(ir::value *x);
-bool is_loop_latch(ir::phi_node *phi, ir::value *terminator);
+bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
 ir::value *get_reference(ir::value *x);
 void replace(ir::value* before, ir::value *after);
@@ -15,7 +15,7 @@ namespace codegen{

 typedef unsigned slot_index;

-class buffer_info_pass;
+class shmem_info;

 struct segment {
 slot_index start;
@@ -30,7 +30,7 @@ struct segment {
 }
 };

-class liveness {
+class shmem_liveness {
 private:
 typedef std::map<ir::value*, slot_index> indices_map_t;
 typedef std::map<ir::value*, segment> intervals_map_t;
@@ -43,7 +43,7 @@ public:

 public:
 // constructor
-liveness(buffer_info_pass *info): info_(info){ }
+shmem_liveness(shmem_info *info): info_(info){ }

 // accessors
 const intervals_map_t& intervals() const { return intervals_; }
@@ -53,7 +53,7 @@ public:
 void run(ir::module &mod);

 private:
-buffer_info_pass *info_;
+shmem_info *info_;
 has_storage_map_t has_dedicated_storage_;
 indices_map_t indices_;
 intervals_map_t intervals_;
@@ -24,6 +24,7 @@ public:
  virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
  virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
  virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
  virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
  bool is_gpu() const;

private:
@@ -37,6 +38,7 @@ public:
  llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
  llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
  llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
  llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};

class nvidia_cu_target: public target {
@@ -46,6 +48,7 @@ public:
  llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
  llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
  llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
  llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};

class cpu_target: public target {
@@ -55,6 +58,7 @@ public:
  llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
  llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
  llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
  llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};

}

@@ -90,7 +90,7 @@ class cu_module: public module {
public:
  cu_module(driver::context* context, llvm::Module *module);
  cu_module(driver::context* context, const std::string& source);
  cu_buffer symbol(const char * name) const;
  cu_buffer* symbol(const char * name) const;

private:
  std::string source_;

@@ -67,6 +67,7 @@ public:
  value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
  value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
  value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
  value *create_downcast(value *arg, const std::string &name = "");
  // Phi instruction
  phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
  // Binary instructions
@@ -124,7 +125,11 @@ public:
  value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
  // Built-in instruction
  value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
  value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
  value *create_get_range_id(unsigned axis, const std::string &name = "");
  value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
  value *create_dot(value *A, value *B, value *C, const std::string &name = "");
  value *create_trans(value *A, const std::string &name = "");
  value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
  // Intrinsics
  value *create_copy_to_shared(value *arg, const std::string &name = "");
  value *create_vectorize(value *arg, const std::string &name = "");

@@ -54,6 +54,7 @@ public:
  void set_value(uint64_t value) { has_value_ = true; value_ = value; }
  bool has_value() { return has_value_; }
  const std::vector<unsigned>& get_space() { return space_; }
  void set_space(const std::vector<unsigned> &space) { space_ = space; }

private:
  std::vector<unsigned> space_;
@@ -464,6 +464,17 @@ public:
};


// downcast

class downcast_inst: public unary_inst {
private:
  using unary_inst::unary_inst;
  std::string repr_impl() const { return "downcast"; }

public:
  static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};

//===----------------------------------------------------------------------===//
// builtin_inst classes
//===----------------------------------------------------------------------===//
@@ -488,17 +499,76 @@ private:
  unsigned axis_;
};

class matmul_inst: public builtin_inst {
class get_range_id_inst: public builtin_inst {
private:
  matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
  std::string repr_impl() const { return "dot"; }
  get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
  std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; }

public:
  static instruction* create(value *A, value *B, value *C,
                             const std::string &name = "",
                             instruction *next = nullptr);
  static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
  unsigned get_axis() const { return axis_; }

private:
  unsigned axis_;
};

class atomic_cas_inst: public builtin_inst {
private:
  atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
  std::string repr_impl() const { return "atomic_cas"; }

public:
  static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};

class dot_inst: public builtin_inst {
public:
  enum TransT { NoTrans, Trans };

private:
  dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
  std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); }

public:
  static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
  static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
  static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
  static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
  bool is_a_trans() { return AT_ == Trans; }
  bool is_b_trans() { return BT_ == Trans; }

private:
  TransT AT_;
  TransT BT_;
};

//class outer_inst: public builtin_inst {
//private:
//  outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
//public:
//  static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
//};

class trans_inst: public builtin_inst {
public:
  ir::type* get_res_ty(ir::type* in);

private:
  trans_inst(value *arg, const std::string& name, instruction* next);
  std::string repr_impl() const { return "trans"; }

public:
  static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};

class select_inst: public builtin_inst {
private:
  select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
  std::string repr_impl() const { return "select"; }

public:
  static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
};

//===----------------------------------------------------------------------===//
// intrinsics classes
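[note] The dot_inst factories above make the layout of each operand explicit in the opcode (dot.nn, dot.nt, ...). A short usage sketch, not part of the diff: how a pass can fold a transpose into the dot, mirroring what lib/codegen/optimize_dot.cpp does later in this commit. The helper name is hypothetical; the factory signatures are those declared above.

// Hypothetical helper, assuming `builder` belongs to the module owning `dot`:
static void rewrite_as_nt(triton::ir::builder &builder, triton::ir::dot_inst *dot) {
  triton::ir::value *A = dot->get_operand(0);
  triton::ir::value *B = dot->get_operand(1);
  triton::ir::value *D = dot->get_operand(2);
  builder.set_insert_point(dot);
  triton::ir::value *BT = builder.create_trans(B);       // materialize B^T
  triton::ir::instruction *nt =
      builder.insert(triton::ir::dot_inst::create_nt(A, BT, D));
  dot->replace_all_uses_with(nt);                        // dot.nn -> dot.nt
}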
@@ -66,6 +66,7 @@ public:
  // Getters
  value *get_value(const std::string& name, basic_block* block);
  value *get_value(const std::string& name);
  const std::string& get_name();
  std::function<ir::value*()> get_continue_fn();
  // Seal block -- no more predecessors will be added
  void seal_block(basic_block *block);

@@ -10,13 +10,15 @@
#include "triton/driver/kernel.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/optimize_cse.h"
#include "triton/codegen/optimize_trans.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/target.h"
#include "triton/codegen/vectorize.h"
#include <functional>

namespace llvm {
@@ -45,48 +47,59 @@ public:

struct passes_wrapper {
  passes_wrapper(codegen::target* target)
    : shared(&buffer_info), liveness(&buffer_info),
      allocation(&liveness, &buffer_info),
      barriers(&allocation, &buffer_info),
    : shmem_liveness(&shmem_info),
      shmem_allocation(&shmem_liveness, &shmem_info),
      shmem_barriers(&shmem_allocation, &shmem_info),
      vectorize(&tune),
      selection(&allocation, &tune, &buffer_info, target),
      selection(&shmem_allocation, &tune, &shmem_info, target),
      optimize_dot(&tune),
      optimize_cse(),
      optimize_trans(),
      target_(target) { }

  void init(ir::module &module) {
  void target_independent(ir::module &module) {
    optimize_dot.run(module);
    optimize_trans.run(module);
    // ir::print(module, std::cout);
  }

  void target_dependent(ir::module &module) {
    if(target_->is_gpu()){
      buffer_info.run(module);
      shared.run(module);
      liveness.run(module);
      allocation.run();
      barriers.run(module);
      shmem_info.run(module);
      shmem_liveness.run(module);
      shmem_allocation.run();
      shmem_barriers.run(module);
    }
    vectorize.run(module);
  }

  codegen::tune tune;
  codegen::buffer_info_pass buffer_info;
  codegen::place_shared_copy shared;
  codegen::liveness liveness;
  codegen::allocation allocation;
  codegen::barriers barriers;
  codegen::shmem_info shmem_info;
  codegen::shmem_liveness shmem_liveness;
  codegen::shmem_allocation shmem_allocation;
  codegen::shmem_barriers shmem_barriers;
  codegen::vectorize vectorize;
  codegen::selection selection;
  codegen::optimize_dot optimize_dot;
  codegen::optimize_cse optimize_cse;
  codegen::optimize_trans optimize_trans;
  codegen::target* target_;
};

private:
  std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
  std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
  std::unique_ptr<ir::module> make_triton_module(const std::string &src);
  std::unique_ptr<ir::module> make_triton_module(const std::string &name, const std::string &src);

public:
  jit(driver::context* context);
  void autotune(const std::string &src, benchmark_t benchmark);
  void autotune(const std::string &name, const std::string &src, benchmark_t benchmark);
  void add_module(ir::module &module, const std::vector<unsigned>& params = {});
  void add_module(const std::string &src, const std::vector<unsigned>& params = {});
  void add_module(const std::string &name, const std::string &src, const std::vector<unsigned>& params = {});
  driver::kernel* get_function(const std::string &name);
  launch_information get_launch_info(const std::string &name);
  unsigned get_int(const std::string &name);
  driver::buffer *get_buffer(const std::string &name);

private:
  std::vector<driver::module*> modules_;
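[note] For orientation, a minimal sketch of how a caller might drive the split pipeline above (assumes passes_wrapper is publicly accessible and that `target` and `module` are valid; tuning and autotuning are handled elsewhere in jit.cpp; not part of the diff):

// Minimal sketch under the assumptions stated above:
void run_passes(triton::codegen::target *target, triton::ir::module &module) {
  triton::jit::passes_wrapper passes(target);
  passes.target_independent(module);  // dot/trans rewrites, valid for any target
  passes.target_dependent(module);    // shmem_* passes run only when target_->is_gpu()
}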
@@ -95,55 +95,75 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
  throw std::runtime_error("unreachable");
}

void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
  ir::value *tmp = ir::undef_value::get(ty);
  implicit_broadcast(mod, arg, tmp);
}

void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
  ir::builder &builder = mod->get_builder();
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
  // Both are scalar
  ir::type *res_ty = nullptr;
  if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
    return;
  // One argument is scalar
  if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
    auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
    auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
    scalar = builder.create_splat(scalar, shapes);
  else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
    res_ty = lhs_ty;
  else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
    res_ty = rhs_ty;
  else{
    auto lhs_shapes = lhs_ty->get_tile_shapes();
    auto rhs_shapes = rhs_ty->get_tile_shapes();
    size_t lhs_size = lhs_shapes.size();
    size_t rhs_size = rhs_shapes.size();
    size_t res_size = std::max(lhs_size, rhs_size);
    ir::type::tile_shapes_t res_shapes(res_size);
    ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
    for(int i = 0; i < res_size; i++){
      if(i >= res_size - lhs_size && i >= res_size - rhs_size)
        res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
      else if(i >= res_size - lhs_size)
        res_shapes[i] = lhs_shapes[i];
      else if(i >= res_size - rhs_size)
        res_shapes[i] = rhs_shapes[i];
    }
    res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
  }
  implicit_broadcast(mod, res_ty, rhs);
  implicit_broadcast(mod, res_ty, lhs);
}

void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
  ir::builder &builder = mod->get_builder();
  ir::type *src_ty = src->get_type();
  ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
  // Both are scalar
  if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
    return;
  // Broadcast scalar
  if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
    src = builder.create_splat(src, ty->get_tile_shapes());
    return;
  }
  // Downcast tile
  if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
    for(ir::constant *shape: src_ty->get_tile_shapes())
      if(shape != one)
        throw std::runtime_error("cannot downcast");
    src = builder.create_downcast(src);
    return;
  }
  // Both are arrays
  auto lhs_shapes = lhs->get_type()->get_tile_shapes();
  auto rhs_shapes = rhs->get_type()->get_tile_shapes();
  if(lhs_shapes == rhs_shapes)
    return;
  int lhs_dim = lhs_shapes.size();
  int rhs_dim = rhs_shapes.size();
  auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
  auto &longest  = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
  size_t ndim = longest.size();
  int off = longest.size() - shortest.size();
  for(int i = longest.size() - 1; i>= 0; i--){
    if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
      throw std::runtime_error("cannot broadcast");
  }
  auto dst_shapes = ty->get_tile_shapes();
  auto src_shapes = src_ty->get_tile_shapes();
  int dst_dim = dst_shapes.size();
  int src_dim = src_shapes.size();
  // Pad
  int off = dst_dim - src_dim;
  for(size_t i = 0; i < off; i++)
    shortest.insert(shortest.begin(), one);
  ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
  src_shapes.insert(src_shapes.begin(), one);
  if(off > 0)
    target = builder.create_reshape(target, shortest);
  src = builder.create_reshape(src, src_shapes);
  // Broadcast
  ir::type::tile_shapes_t shapes(ndim);
  for(size_t i = 0; i < ndim; i++)
    shapes[i] = shortest[i]==one?longest[i]:shortest[i];
  if(shapes != lhs_shapes)
    lhs = builder.create_broadcast(lhs, shapes);
  if(shapes != rhs_shapes)
    rhs = builder.create_broadcast(rhs, shapes);
  for(int i = dst_dim - 1; i>= 0; i--)
    if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
      throw std::runtime_error("cannot broadcast");
  if(dst_shapes != src_shapes)
    src = builder.create_broadcast(src, dst_shapes);
}

/* Helper */
@@ -336,7 +356,9 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
    return builder.create_cond_br(cond, loop_bb, next_bb);
  });
  init_->codegen(mod);
  builder.create_br(loop_bb);
  ir::value *cond = stop_->codegen(mod);
  builder.create_cond_br(cond, loop_bb, next_bb);
  // builder.create_br(loop_bb);
  builder.set_insert_point(loop_bb);
  if(!is_terminator(statements_->codegen(mod)))
    mod->get_continue_fn()();
@@ -378,6 +400,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
    builder.create_br(endif_bb);
  }
  // Endif
  mod->seal_block(endif_bb);
  builder.set_insert_point(endif_bb);
  return nullptr;
}
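[note] The loop-codegen change above replaces the unconditional branch into the body with a guard on the stop condition, so a zero-trip for-loop no longer runs its body once. Sketched as C++ comments (labels are illustrative, not diff content):

// before:  init();  goto loop;                            // body always ran >= 1 time
// after:   init();  if(stop()) goto loop; else goto next;
// loop:    body();  continue_fn();                        // continue_fn emits the back-edge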
@@ -422,7 +445,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
  else if(expr_){
    value = expr_->codegen(mod);
    value = explicit_cast(mod->get_builder(), value, ty);
    implicit_broadcast(mod, value, ty);
    implicit_broadcast(mod, ty, value);
  }
  value->set_name(name);
  mod->set_value(name, value);
@@ -543,6 +566,19 @@ ir::value* get_global_range::codegen(ir::module *mod) const {
  return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
}

// get_range_id
ir::value* get_range_id::codegen(ir::module *mod) const {
  return mod->get_builder().create_get_range_id(axis_->value());
}

// atomic cas
ir::value* atomic_cas::codegen(ir::module *mod) const {
  ir::value *ptr = ptr_->codegen(mod);
  ir::value *cmp = cmp_->codegen(mod);
  ir::value *val = val_->codegen(mod);
  return mod->get_builder().create_atomic_cas(ptr, cmp, val);
}

// matmul
ir::value* matmul_expression::codegen(ir::module *mod) const {
  ir::value *A = A_->codegen(mod);
@@ -554,10 +590,37 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
  // ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
  // ir::value *tmp = ir::undef_value::get(tile_ty);
  // implicit_broadcast(mod, tmp, C);
  return mod->get_builder().create_matmul(A, B, C);
  return mod->get_builder().create_dot(A, B, C);
}

// min
ir::value* min_expression::codegen(ir::module *mod) const {
  ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod);
  ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
  ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
  return mod->get_builder().create_select(cmp, x, y);
}

// max
ir::value* max_expression::codegen(ir::module *mod) const {
  ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod);
  ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
  ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
  return mod->get_builder().create_select(cmp, x, y);
}

// select
ir::value* select_expression::codegen(ir::module *mod) const {
  ir::value* pred = pred_->codegen(mod);
  ir::value* if_value = if_value_->codegen(mod);
  ir::value* else_value = else_value_->codegen(mod);
  return mod->get_builder().create_select(pred, if_value, else_value);
}

// Trans
ir::value* trans_expression::codegen(ir::module *mod) const {
  return mod->get_builder().create_trans(arg_->codegen(mod));
}

/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
@@ -573,6 +636,7 @@ ir::value* indexing_expression::codegen(ir::module *mod) const{
  return mod->get_builder().create_reshape(in, out_shapes);
}


/* Unary operator */
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
  ir::type *atype = arg->get_type();
@@ -666,7 +730,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
  if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
    ir::type *ty = mod->get_scope().types.at(x->id()->name());
    rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
    implicit_broadcast(mod, rvalue, ty);
    implicit_broadcast(mod, ty, rvalue);
    mod->set_value(x->id()->name(), rvalue);
  }
  else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
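[note] The new min/max builtins above are desugared rather than lowered to dedicated IR instructions; roughly (a sketch of the emitted pattern, not diff content):

// min(a, b):  %cmp = (a < b)   ->  %res = select(%cmp, a, b)
// max(a, b):  %cmp = (a > b)   ->  %res = select(%cmp, a, b)
// Reading x and y back off the cmp_inst reuses the operands after
// binary_operator's implicit casting/broadcasting has already run on them.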
@@ -1,90 +0,0 @@
#include "triton/codegen/buffer_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"

namespace triton {

namespace codegen{


// run pass on module
bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
  if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
    return br->get_true_dest() == phi->get_parent()
           || br->get_false_dest() == phi->get_parent();
  else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  else
    throw std::runtime_error("unreachable");
}

void buffer_info_pass::replace(ir::value* before, ir::value *after) {
  shared_.erase(before);
  shared_.insert(after);
  if(refs_.find(before) != refs_.end()){
    ir::value* v = refs_.at(before);
    refs_.erase(before);
    refs_.insert({after, v});
  }
}

void buffer_info_pass::run(ir::module &mod) {
  // Find which buffers are shared
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list())
    if(dynamic_cast<ir::matmul_inst*>(i)){
      shared_.insert(i->get_operand(0));
      shared_.insert(i->get_operand(1));
    }

  // Handles phi nodes
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()) {
    if(!i->get_type()->is_tile_ty())
      continue;
    // handle phi
    if(auto *phi = dynamic_cast<ir::phi_node*>(i))
    if(is_shared(phi)){
      // determine if the value is in shared memory
      bool is_double = false;
      for(unsigned n = 0; n < phi->get_num_incoming(); n++){
        ir::basic_block *inc_block = phi->get_incoming_block(n);
        ir::value *terminator = inc_block->get_inst_list().back();
        is_double = is_double || is_loop_latch(phi, terminator);
      }
      // add to double-buffered
      if(is_double)
        double_.insert(phi);
      // set references of input
      for(unsigned n = 0; n < phi->get_num_incoming(); n++){
        ir::value *inc_val = phi->get_incoming_value(n);
        refs_[inc_val] = phi;
      }
    }
  }

  for(auto &ref: refs_)
    shared_.insert(ref.first);
}

// query double-buffered status
bool buffer_info_pass::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }

// query shared status
bool buffer_info_pass::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }

// get reference if any
ir::value *buffer_info_pass::get_reference(ir::value *x)
{ return refs_[x]; }


}
}
@@ -1,56 +0,0 @@
#include "triton/codegen/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"

namespace triton{
namespace codegen{


shared_view_info layout::get_shared_view(ir::value *v, unsigned idx){
  return shared_views_.at(v)[idx];
}

unsigned layout::get_num_shared_views(ir::value *v){
  return shared_views_.at(v).size();
}

// Phi node
void layout::add_phi_nodes(ir::value *v){
  if(ir::phi_node *phi = dynamic_cast<ir::phi_node*>(v))
  if(shared_views_.find(phi) != shared_views_.end())
  for(ir::value *v: phi->ops()){
    shared_views_[v] = shared_views_[phi];
    for(shared_view_info &info: shared_views_[v])
      info.has_dedicated_storage = false;
  }
}

// Memory Layout
void layout::add_shared_views(ir::value *v){
  // GEMM has shared inputs
  if(dynamic_cast<ir::matmul_inst*>(v))
    shared_views_[v].push_back({v, true});
  if(dynamic_cast<ir::reshape_inst*>(v))
    shared_views_[v].push_back({v, true});
}

// Entry point
void layout::run(ir::module &mod) {
  for(ir::function *fn: mod.get_function_list()){
    // Non-phis
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *instr: block->get_inst_list()) {
      add_shared_views(instr);
    }
    // Phi nodes
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *instr: block->get_inst_list()) {
      add_phi_nodes(instr);
    }
  }
}

}
}
14
lib/codegen/optimize_cse.cpp
Normal file
@@ -0,0 +1,14 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/optimize_cse.h"

namespace triton {
namespace codegen{


void optimize_cse::run(ir::module &mod) {
}

}
}
50
lib/codegen/optimize_dot.cpp
Normal file
@@ -0,0 +1,50 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/tune.h"

namespace triton {
namespace codegen{

inline bool is_trans(ir::value *v){
  return dynamic_cast<ir::trans_inst*>(v) != nullptr;
}

void optimize_dot::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  std::vector<ir::instruction*> to_delete;
  // iterate
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list())
    if(auto dot = dynamic_cast<ir::dot_inst*>(i))
    if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1)
    if(!dot->is_a_trans() && !dot->is_b_trans()){
      builder.set_insert_point(i);
      ir::value *A = dot->get_operand(0);
      ir::value *B = dot->get_operand(1);
      ir::value *D = dot->get_operand(2);
      // dot(op(a), trans(b))
      if(is_trans(B)){
        ir::value* BN = ((ir::trans_inst*)B)->get_operand(0);
        ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BN, D));
        dot->replace_all_uses_with(NT);
        to_delete.push_back((ir::instruction*)B);
        to_delete.push_back(dot);
      }
      // dot(op(a), b)
      if(!is_trans(B)){
        ir::value* BT = builder.create_trans(B);
        ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BT, D));
        dot->replace_all_uses_with(NT);
        to_delete.push_back(dot);
      }
    }

  for(ir::instruction* i: to_delete)
    i->erase_from_parent();
}

}
}
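[note] Sketch of the rewrite optimize_dot performs on the IR (the %-names are illustrative; the opcodes follow repr_impl(), not diff content):

// case 1 (operand already transposed): fuse the trans into the dot
//   %t = trans %b ; %c = dot.nn %a, %t, %d   =>   %c = dot.nt %a, %b, %d
// case 2 (no transpose present): create one so the NT form can still be used
//   %c = dot.nn %a, %b, %d   =>   %t = trans %b ; %c = dot.nt %a, %t, %d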
71
lib/codegen/optimize_trans.cpp
Normal file
@@ -0,0 +1,71 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/optimize_trans.h"

namespace triton {
namespace codegen{


ir::value* optimize_trans::replace_phi(ir::value* value,
                                       std::vector<ir::instruction*>& to_delete,
                                       ir::builder& builder){
  if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
    // transpose operands
    std::vector<ir::value*> incs;
    for(unsigned n = 0; n < phi->get_num_incoming(); n++)
      incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder));
    // create phi for transposed values
    builder.set_insert_point(phi);
    ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
    for(unsigned n = 0; n < phi->get_num_incoming(); n++)
      result->add_incoming(incs[n], phi->get_incoming_block(n));
    phi->replace_all_uses_with(result);
    to_delete.push_back(phi);
    return result;
  }
  else if(auto i = dynamic_cast<ir::instruction*>(value)){
    ir::basic_block* block = i->get_parent();
    auto it = std::find(block->begin(), block->end(), i);
    it++;
    builder.set_insert_point(it);
    ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
    i->replace_all_uses_with(trans);
    trans->set_operand(0, i);
    return trans;
  }
  throw std::runtime_error("cannot transpose phi");
}


void optimize_trans::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  std::vector<ir::instruction*> to_delete;
  // iterate
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction* i: block->get_inst_list()){
    // filter transposition
    if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
      auto users = trans->get_users();
      auto ops = trans->ops();
      if(users.size() > 1 || ops.size() > 1)
        continue;
      ir::value* op = *ops.begin();
      // chains of transpositions
      // TODO

      // trans(phi) -> phi(trans(), trans()...)
      if(dynamic_cast<ir::phi_node*>(op)){
        ir::value* new_phi = replace_phi(op, to_delete, builder);
        to_delete.push_back(trans);
        trans->replace_all_uses_with(new_phi);
      }
    }
  }
  // erase dead code
  for(ir::instruction* i: to_delete)
    i->erase_from_parent();
}

}
}
@@ -1,6 +1,6 @@
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/target.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
@@ -309,7 +309,47 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
  Value *ptr = value(ii->get_pointer_operand());
  return builder.Insert(new LoadInst(ptr));
  LoadInst *result = new LoadInst(ptr);
  return builder.Insert(result);
}
if(ir::store_inst* ii = dynamic_cast<ir::store_inst*>(inst)){
  Value *val = value(ii->get_value_operand());
  Value *ptr = value(ii->get_pointer_operand());
  builder.CreateStore(val, ptr);
  return nullptr;
}
if(ir::select_inst* ii = dynamic_cast<ir::select_inst*>(inst)){
  Value *pred = value(ii->get_operand(0));
  Value *if_value = value(ii->get_operand(1));
  Value *else_value = value(ii->get_operand(2));
  return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
  Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
  return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
}
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
  BasicBlock *current = builder.GetInsertBlock();
  Module *module = current->getModule();
  Value *tid = tgt_->get_local_id(module, builder, 0);
  Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
  BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
  BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
  Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
  ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
  builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
  builder.SetInsertPoint(tid_0_bb);
  Value *cas_ptr = value(ii->get_operand(0));
  Value *cas_cmp = value(ii->get_operand(1));
  Value *cas_val = value(ii->get_operand(2));
  Value *old = builder.CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
  old = builder.CreateExtractValue(old, {0});
  builder.CreateStore(old, ptr);
  builder.CreateBr(tid_0_done_bb);
  builder.SetInsertPoint(tid_0_done_bb);
  tgt_->add_barrier(module, builder);
  Value *res = builder.CreateLoad(ptr);
  return (Instruction*)res;
}
// unknown instruction
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
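[note] The atomic_cas lowering above is a single-lane pattern; roughly the device-side shape it compiles to (a C-style sketch, not diff content):

// if(tid == 0) {                     // only lane 0 performs the CAS
//   old = cmpxchg(ptr, cmp, val);    // monotonic (relaxed) ordering on both paths
//   shared_slot = old;               // stash the old value in shared memory
// }
// barrier();                         // make the slot visible to every lane
// result = shared_slot;              // all lanes read the same old value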
@@ -446,7 +486,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
    bind_references(op);
  // bind
  const auto& shapes = v->get_type()->get_tile_shapes();
  if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
  if(buffer_info_->is_shared(v))
    return;
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d]->get_value() == 1)
@@ -490,20 +530,11 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
    shapes2.push_back(shape->get_value());
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
  // create shared tile
  if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
  if(buffer_info_->is_shared(v)){
    // shared copy
    PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
    // TODO - buffer info not up-to-date with references
    if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
    if(!has_phi_user(v)){
      size_t offset = alloc_->get_offset(v);
      Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
      ptr = builder.CreateBitCast(ptr, ptr_ty);
      tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
    }
    }
    // phi-node (double-buffering)
    else if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
    if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
      BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
      unsigned id_pre = 0, id_loop = 1;
      if(phi->get_incoming_block(0) == phi->get_parent())
@@ -522,13 +553,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
      for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
        ir::basic_block* inc_block = phi->get_incoming_block(i);
        ir::value* inc_value = phi->get_incoming_value(i);
        ir::value* terminator = inc_block->get_inst_list().back();
        ir::instruction* terminator = inc_block->get_inst_list().back();
        bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
        tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
      }
    }
    else
      throw std::runtime_error("unknown shared memory tile");
    else {
      if(!has_phi_user(v)){
        size_t offset = alloc_->get_offset(v);
        Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
        ptr = builder.CreateBitCast(ptr, ptr_ty);
        tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
      }
    }
  }
  // create distributed tile
  else {
@@ -607,10 +644,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
  tile *value = tmap_.at(x->get_value_operand());
  ptr->for_each([&](indices_t idx){
    set_mask_insert_pt(idx);
    builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
    StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
    // store->setAlignment(16);
    builder.Insert(store);
  });
}
else {
  if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
    vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
    return;
  }
  tile *ti = tmap_[ins];
  distributed_tile* result = (distributed_tile*)ti;
  if(!ins->get_type()->is_tile_ty())
@@ -727,31 +770,67 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
    ti->set_value(idx, in->get_value(idx));
  });
}
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins) || (buffer_info_->is_double(ins)))
// trans
else if(dynamic_cast<ir::trans_inst*>(ins)) {
  distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
  in->for_each([&](indices_t idx){
    indices_t out_idx = idx;
    std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
    ti->set_value(out_idx, in->get_value(idx));
  });
}
else if(buffer_info_->is_shared(ins))
  return;
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
// dot
else if(auto dot = dynamic_cast<ir::dot_inst*>(ins)) {
  ir::value *A = ins->get_operand(0);
  ir::value *B = ins->get_operand(1);
  ir::value *C = ins->get_operand(2);
  shared_tile *TA = (shared_tile*)tmap_.at(A);
  shared_tile *TB = (shared_tile*)tmap_.at(B);
  bool AT = dot->is_a_trans();
  bool BT = dot->is_b_trans();
  distributed_tile *TC = (distributed_tile*)tmap_.at(C);
  TA->set_vector_size(TC->axis(0).contiguous);
  TB->set_vector_size(TC->axis(1).contiguous);
  Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
  result->for_each([&](indices_t idx){
    Value *res = TC->get_value(idx);
    unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
    for(unsigned K = 0; K < NK; ++K){
      indices_t a_idx = {idx[0], builder.getInt32(K)};
      indices_t b_idx = {idx[1], builder.getInt32(K)};
  if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
  {
    shared_tile *TA = (shared_tile*)tmap_.at(A);
    shared_tile *TB = (shared_tile*)tmap_.at(B);
    TA->set_vector_size(TC->axis(0).contiguous);
    TB->set_vector_size(TC->axis(1).contiguous);
    result->for_each([&](indices_t idx){
      Value *res = TC->get_value(idx);
      unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
      for(unsigned K = 0; K < NK; ++K){
        indices_t a_idx = {idx[0], builder.getInt32(K)};
        indices_t b_idx = {builder.getInt32(K), idx[1]};
        if(AT)
          std::swap(a_idx[0], a_idx[1]);
        if(BT)
          std::swap(b_idx[0], b_idx[1]);
        Value *a = TA->get_value(a_idx);
        Value *b = TB->get_value(b_idx);
        res = builder.CreateCall(f_mul_add, {a, b, res});
      }
      result->set_value(idx, res);
    });
  }
  else
  {
    distributed_tile *TA = (distributed_tile*)tmap_.at(A);
    distributed_tile *TB = (distributed_tile*)tmap_.at(B);
    result->for_each([&](indices_t idx){
      Value *res = TC->get_value(idx);
      indices_t a_idx = {idx[0], builder.getInt32(0)};
      indices_t b_idx = {builder.getInt32(0), idx[1]};
      if(AT)
        std::swap(a_idx[0], a_idx[1]);
      if(BT)
        std::swap(b_idx[0], b_idx[1]);
      Value *a = TA->get_value(a_idx);
      Value *b = TB->get_value(b_idx);
      res = builder.CreateCall(f_mul_add, {a, b, res});
      result->set_value(idx, res);
    });
  }
}
// element-wise
else {
@@ -858,6 +937,7 @@ void selection::run(ir::module &src, Module &dst) {
      nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
    sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
  }
  sh_mem_ptr_ = sh_mem_ptr;

  // create grids
  init_grids(fn, dst_builder, sh_mem_ptr);
@@ -890,7 +970,7 @@ void selection::run(ir::module &src, Module &dst) {
  for(unsigned n = 0; n < phi->get_num_incoming(); n++){
    ir::basic_block* inc_block = phi->get_incoming_block(n);
    ir::value* inc_val = phi->get_incoming_value(n);
    ir::value* terminator = inc_block->get_inst_list().back();
    ir::instruction* terminator = inc_block->get_inst_list().back();
    BasicBlock *llvm_inc_block = last_block.at(inc_block);
    shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
    bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
@@ -920,8 +1000,8 @@ void selection::run(ir::module &src, Module &dst) {
    });
  }
  else {
    PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
    Value *llvm_inc_val = vmap_.at(inc_val);
    PHINode *llvm_phi = (PHINode*)llvm_value(phi, dst_builder);
    Value *llvm_inc_val = llvm_value(inc_val, dst_builder);
    llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
  }
}

@@ -1,40 +0,0 @@
#include <algorithm>
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/buffer_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"

namespace triton {

namespace codegen{

void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) {
  if(auto *i = dynamic_cast<ir::instruction*>(x)){
    ir::basic_block* block = i->get_parent();
    auto it = std::find(block->begin(), block->end(), i);
    builder.set_insert_point(++it);
  }
  ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
  x->replace_all_uses_with(rx);
  rx->set_operand(0, x);
}

void place_shared_copy::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list())
    if(info_->is_shared(i) && !info_->is_double(i))
      add_copy(i, builder);

  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list())
    if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(i))
      info_->replace(cts->get_operand(0), cts);
}

}
}
@@ -1,7 +1,6 @@
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/layout.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -11,14 +10,14 @@
namespace triton{
namespace codegen{

unsigned allocation::get_num_bytes(ir::value *x) {
  unsigned result = x->get_type()->get_tile_bitwidth() / 8;
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
  unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
  if(buffer_info_->is_double(x))
    result *= 2;
  return result;
}

void allocation::run(){
void shmem_allocation::run(){
  using std::max;
  using std::min;
  typedef std::multimap<unsigned, segment> triples_map_type;
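[note] A worked example of the size computation above (illustrative numbers, assuming get_primitive_size_in_bits() returns the total bit-width of the tile):

// fp32 tile of shape {64, 8}:   (64 * 8 * 32) / 8  = 2048 bytes
// double-buffered (is_double):   2048 * 2          = 4096 bytes of shared memory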
@@ -1,7 +1,7 @@
#include <algorithm>
#include "triton/codegen/barriers.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -12,7 +12,7 @@ namespace triton {

namespace codegen{

bool barriers::intersect(const interval_vec_t &X, interval_t x) {
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
    bool left_intersect = y.first <= x.first && x.first < y.second;
    bool right_intersect = y.first <= x.second && x.second < y.second;
@@ -20,31 +20,31 @@ bool barriers::intersect(const interval_vec_t &X, interval_t x) {
  });
}

bool barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
    return intersect(X, y);
  });
}

void barriers::add_reference(ir::value *v, interval_vec_t &res){
  if(dynamic_cast<ir::copy_to_shared_inst*>(v)){
void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
    unsigned offset = alloc_->get_offset(v);
    unsigned num_bytes = alloc_->get_num_bytes(v);
    res.push_back(interval_t(offset, offset + num_bytes));
  }
}

void barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
  for(ir::value *op: i->ops())
    add_reference(op, res);
}

void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  if(!dynamic_cast<ir::phi_node*>(i))
    add_reference(i, res);
}

void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
    std::set<ir::value*> incoming;
    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,16 +63,16 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
  }
}

barriers::interval_vec_t barriers::join(const std::vector<interval_vec_t>& intervals) {
  barriers::interval_vec_t result;
shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector<interval_vec_t>& intervals) {
  shmem_barriers::interval_vec_t result;
  for(auto x: intervals)
  for(interval_t i: x)
    result.push_back(i);
  return result;
}

std::pair<barriers::interval_vec_t,
          barriers::interval_vec_t> barriers::transfer(ir::basic_block *block,
std::pair<shmem_barriers::interval_vec_t,
          shmem_barriers::interval_vec_t> shmem_barriers::transfer(ir::basic_block *block,
                                                       const interval_vec_t &written_to,
                                                       const interval_vec_t &read_from,
                                                       std::set<ir::instruction*>& insert_loc) {
@@ -83,13 +83,13 @@ std::pair<barriers::interval_vec_t,
  interval_vec_t read, written;
  get_read_intervals(i, read);
  get_written_intervals(i, written);
  bool read_while_written = intersect(new_written_to, read);
  bool written_while_read = intersect(new_read_from, written);
  bool read_after_write = intersect(new_written_to, read);
  bool write_after_read = intersect(new_read_from, written);
  // double buffering: write and phi-node read won't intersect
  if(dynamic_cast<ir::copy_to_shared_inst*>(i) &&
  if(buffer_info_->is_shared(i) &&
     buffer_info_->is_double(buffer_info_->get_reference(i)))
    written_while_read = false;
    write_after_read = false;
  if(read_while_written || written_while_read) {
  if(read_after_write || write_after_read) {
    insert_loc.insert(i);
    new_written_to.clear();
    new_read_from.clear();
@@ -100,7 +100,7 @@ std::pair<barriers::interval_vec_t,
  return std::make_pair(new_written_to, new_read_from);
}

void barriers::run(ir::module &mod) {
void shmem_barriers::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
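[note] The hazard rule implemented by transfer() above, condensed (W and R are the running written/read interval sets; a sketch, not diff content):

// for each instruction i with read intervals r_i and write intervals w_i:
//   RAW: r_i intersects W  ->  barrier before i
//   WAR: w_i intersects R  ->  barrier before i  (skipped for double-buffered copies)
// once a barrier is scheduled at i, both W and R are cleared.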
135
lib/codegen/shmem_info.cpp
Normal file
@@ -0,0 +1,135 @@
#include "triton/codegen/shmem_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"

namespace triton {

namespace codegen{


// run pass on module
bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  if(phi->get_parent() != terminator->get_parent())
    return false;
  if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
    return br->get_true_dest() == phi->get_parent()
           || br->get_false_dest() == phi->get_parent();
  else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  else
    throw std::runtime_error("unreachable");
}

void shmem_info::replace(ir::value* before, ir::value *after) {
  shared_.erase(before);
  shared_.insert(after);
  if(refs_.find(before) != refs_.end()){
    ir::value* v = refs_.at(before);
    refs_.erase(before);
    refs_.insert({after, v});
  }
}

inline bool get_is_shared(ir::value* v) {
  if(auto x = dynamic_cast<ir::atomic_cas_inst*>(v))
    return true;
  if(auto x = dynamic_cast<ir::trans_inst*>(v))
    return true;
  if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
    return true;
  if(auto x = dynamic_cast<ir::phi_node*>(v)){
    bool res = true;
    for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
      res = res && get_is_shared(x->get_incoming_value(inc));
    return res;
  }
  return false;
}

void add_copy(ir::value *x, ir::builder &builder) {
  if(auto phi = dynamic_cast<ir::phi_node*>(x)){
    for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
      add_copy(phi->get_incoming_value(i), builder);
  }
  else {
    if(get_is_shared(x))
      return;
    if(auto *i = dynamic_cast<ir::instruction*>(x)){
      ir::basic_block* block = i->get_parent();
      auto it = std::find(block->begin(), block->end(), i);
      builder.set_insert_point(++it);
    }
    ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
    x->replace_all_uses_with(rx);
    rx->set_operand(0, x);
  }
}

void shmem_info::run(ir::module &mod) {
  // Add shared copies
  for(ir::function *fn: mod.get_function_list()){
    ir::builder builder(mod.get_context());
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *i: block->get_inst_list()){
      if(dynamic_cast<ir::dot_inst*>(i))
      if(i->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
        add_copy(i->get_operand(0), builder);
        add_copy(i->get_operand(1), builder);
      }
    }
  }

  // Find which buffers are shared
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list())
    if(get_is_shared(i))
      shared_.insert(i);

  // double-buffering
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()) {
    if(!i->get_type()->is_tile_ty())
      continue;
    // handle phi
    if(auto *phi = dynamic_cast<ir::phi_node*>(i))
    if(is_shared(phi)){
      // determine if the value is in shared memory
      bool is_double = false;
      for(unsigned n = 0; n < phi->get_num_incoming(); n++){
        ir::basic_block *inc_block = phi->get_incoming_block(n);
        ir::instruction *terminator = inc_block->get_inst_list().back();
        is_double = is_double || is_loop_latch(phi, terminator);
      }
      // add to double-buffered
      if(is_double)
        double_.insert(phi);
      // set references of input
      for(unsigned n = 0; n < phi->get_num_incoming(); n++){
        ir::value *inc_val = phi->get_incoming_value(n);
        refs_[inc_val] = phi;
      }
    }
  }
}

// query double-buffered status
bool shmem_info::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }

// query shared status
bool shmem_info::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }

// get reference if any
ir::value *shmem_info::get_reference(ir::value *x)
{ return refs_[x]; }


}
}
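[note] The phi pattern is_loop_latch() flags, sketched in comments (names assumed; not diff content): a shared tile carried by a phi whose incoming terminator branches back into the phi's own block crosses loop iterations, and is therefore marked double-buffered:

// header:  %a = phi [%a.init, %entry], [%a.next, %header]   // tile in shared memory
//          ...
//          br %cond, label %header, label %exit             // back-edge => is_loop_latch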
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/liveness.h"
|
||||
#include "triton/codegen/buffer_info.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -11,19 +11,7 @@ namespace codegen{
|
||||
|
||||
|
||||
// Entry point
|
||||
inline bool is_shared(ir::value* v) {
|
||||
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return true;
|
||||
if(auto x = dynamic_cast<ir::phi_node*>(v)){
|
||||
bool res = true;
|
||||
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
|
||||
res = res && is_shared(x->get_incoming_value(inc));
|
||||
return res;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void liveness::run(ir::module &mod) {
|
||||
void shmem_liveness::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;

@@ -4,6 +4,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>

using namespace llvm;

@@ -26,6 +27,12 @@ Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
}

Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* group_id = get_block_id(module, builder, ax);
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
}

Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::amdgcn_workgroup_id_x,
    Intrinsic::amdgcn_workgroup_id_y,
@@ -33,8 +40,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
  };
  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
  Value* group_id = builder.CreateCall(get_group_id, {});
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
  return group_id;
}

Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -65,6 +71,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder)
}

Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* group_id = get_block_id(module, builder, ax);
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
}

Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
@@ -72,8 +84,7 @@ Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder,
  };
  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
  Value* group_id = builder.CreateCall(get_group_id, {});
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
  return group_id;
}

Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -97,7 +108,7 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
  return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}

Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
  const Function *fn = builder.GetInsertBlock()->getParent();
  size_t num_params = fn->getFunctionType()->getNumParams();
  static std::array<const Argument*, 3> ids = {
@@ -105,7 +116,11 @@ Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsig
    fn->arg_begin() + num_params - 2,
    fn->arg_begin() + num_params - 1
  };
  Value* result = builder.CreateMul(builder.getInt32(stride), (Argument*)ids[ax]);
  return (Argument*)ids[ax];
}

Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
  return result;
}

@@ -113,6 +128,5 @@ Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned a
  return builder.getInt32(0);
}

}
}
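
A minimal standalone sketch (toy classes, not the LLVM-based API above) of the shape of this refactor: get_global_offset becomes the same stride * block_id computation for every backend, and only get_block_id stays target-specific, whether it comes from an intrinsic call or a trailing kernel argument.

#include <cstdio>

struct target {
  virtual ~target() = default;
  virtual int get_block_id(unsigned ax) = 0;  // backend-specific
  int get_global_offset(unsigned stride, unsigned ax) {
    return stride * get_block_id(ax);         // shared by all backends
  }
};

struct toy_gpu : target {
  int get_block_id(unsigned) override { return 3; } // stand-in for ctaid / workgroup id
};

int main() {
  toy_gpu t;
  std::printf("%d\n", t.get_global_offset(16, 0));  // 16 * 3 = 48
}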

@@ -1,5 +1,4 @@
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -40,6 +39,8 @@ void tune::init_c_graph(ir::instruction *v) {
  ir::type::tile_shapes_t shapes;
  if(auto *store = dynamic_cast<ir::store_inst*>(v))
    shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
  else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
    return;
  else
    shapes = v->get_type()->get_tile_shapes();
  // Reshape
@@ -56,6 +57,14 @@ void tune::init_c_graph(ir::instruction *v) {
  // Splat
  else if(dynamic_cast<ir::splat_inst*>(v)){

  }
  // Trans
  else if(dynamic_cast<ir::trans_inst*>(v)){
    ir::value *op = v->get_operand(0);
    size_t n_shapes = shapes.size();
    for(unsigned i = 0; i < n_shapes; i++){
      add_constraint({v, (i + 1) % n_shapes}, {op, i});
    }
  }
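  // Illustrative note: for a 2-D tile the loop above ties {v, 1} to {op, 0}
  // and {v, 0} to {op, 1}, i.e. the output of trans must use the input's
  // layout parameters with the two axes swapped, exactly a transpose.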
  // Broadcast
  else if(dynamic_cast<ir::broadcast_inst*>(v)){
@@ -68,7 +77,7 @@ void tune::init_c_graph(ir::instruction *v) {
    }
  }
  // Matrix multiplication
  else if(dynamic_cast<ir::matmul_inst*>(v)){
  else if(dynamic_cast<ir::dot_inst*>(v)){
    ir::value *D = v->get_operand(2);
    add_constraint({v, 0}, {D, 0});
    add_constraint({v, 1}, {D, 1});
@@ -119,6 +128,13 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
      if(seen.insert(x.second).second && !x.second->has_value()){
        result.push_back(x.second);
      }

  for(auto x: mod.globals()){
    if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
      if(seen.insert(mp).second && !mp->has_value())
        result.push_back(mp);
  }

  return result;
}
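// Illustrative note: the loop over mod.globals() above also exposes
// module-level "tunable" constants (such as TM, TN and TK declared in the
// Triton-C source) as metaparameters, so the autotuner explores them
// together with the per-instruction layout parameters.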

@@ -145,23 +161,22 @@ void tune::run(ir::module &mod) {
  // Layout parameters
  while(!nodes_.empty()){
    ir::type *ty = mod.get_builder().get_int32_ty();
    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
    nts->set_value(1);
    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
    connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
  }
}

  // Simplify metaparameters
  std::set<ir::metaparameter*> fixed_io_nts;
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i : block->get_inst_list())
    if(dynamic_cast<ir::load_inst*>(i) || dynamic_cast<ir::store_inst*>(i))
      if(i->get_type()->is_tile_ty())
        for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++)
          fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d)));
  for(ir::metaparameter* mp: fixed_io_nts)
    mp->set_value(1);
    if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
      ir::type *ty = mod.get_builder().get_int32_ty();
      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
      *params_.at(i).at("nts.d0") = *tmp;
    }
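    // Illustrative note: the {2, 2} range above pins "nts.d0", the number of
    // contiguous elements per thread along a tile's leading axis, to exactly
    // 2 for loads, which prunes the tuning space to the remaining
    // metaparameters.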
}

void tune::init(ir::module &mod) {
@@ -64,9 +64,6 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
host_buffer::host_buffer(driver::context *context, size_t size)
  : buffer(context, host_buffer_t(), true){
  hst_->data = new char[size];
  std::cout << size << std::endl;
  std::cout << "allocating " << (float*)hst_->data << std::endl;
  std::cout << *((float*)(hst_->data) + 512*500) << std::endl;
}

//
@@ -106,7 +106,11 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
                                 const std::string& features,
                                 file_type_t ft) {
  init_llvm();

  // debug
//  llvm::legacy::PassManager pm;
//  pm.add(llvm::createPrintModulePass(llvm::outs()));
//  pm.add(llvm::createVerifierPass());
//  pm.run(*module);
  // create machine
  module->setTargetTriple(triple);
  std::string error;
@@ -249,6 +253,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
//  std::cout << source << std::endl;
  cu_context::context_switcher ctx_switch(*context);
  // JIT compile source-code
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -264,11 +269,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
  }
}

cu_buffer cu_module::symbol(const char *name) const{
cu_buffer* cu_module::symbol(const char *name) const{
  CUdeviceptr handle;
  size_t size;
  dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
  return cu_buffer(ctx_, handle, false);
  return new cu_buffer(ctx_, handle, false);
}
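
// Illustrative note: symbol() now returns a heap-allocated cu_buffer*, so
// ownership passes to the caller. A hypothetical call site would wrap the
// result immediately to avoid leaking one buffer handle per lookup, e.g.:
//   std::unique_ptr<driver::cu_buffer> buf(mod->symbol("locks"));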

@@ -285,6 +285,10 @@ value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes,
  return insert(broadcast_inst::create(arg, shapes, name));
}

value *builder::create_downcast(value *arg, const std::string &name) {
  return insert(downcast_inst::create(arg, name));
}

//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
@@ -293,8 +297,24 @@ value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::valu
  return insert(get_global_range_inst::create(ctx_, axis, size, name));
}

value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) {
  return insert(matmul_inst::create(A, B, C, name));
value *builder::create_get_range_id(unsigned axis, const std::string &name) {
  return insert(get_range_id_inst::create(ctx_, axis, name));
}

value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){
  return insert(atomic_cas_inst::create(ptr, cmp, val, name));
}

value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
  return insert(dot_inst::create_nn(A, B, C, name));
}

value *builder::create_trans(value *A, const std::string &name) {
  return insert(trans_inst::create(A, name));
}

value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
  return insert(select_inst::create(pred, if_value, else_value, name));
}

//===----------------------------------------------------------------------===//
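
A hedged fragment (hypothetical ir::builder &builder and previously built values a, b, c, ptr, cmp, val) showing how the new entry points compose. Note that create_dot lowers to dot_inst::create_nn, so a transposed operand is spelled explicitly with create_trans, mirroring dot(a, trans(b), c) in Triton-C source:

ir::value *bt  = builder.create_trans(b, "bt");                   // B^T
ir::value *d   = builder.create_dot(a, bt, c, "d");               // non-transposed dot
ir::value *old = builder.create_atomic_cas(ptr, cmp, val, "old"); // yields the value read at ptr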

@@ -28,6 +28,8 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const

void instruction::erase_from_parent() {
  parent_->erase(this);
  for(ir::value* op: ops())
    op->erase_use(this);
}

bool instruction::has_tile_result_or_op() {
@@ -482,27 +484,82 @@ instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shape
  return new broadcast_inst(arg, shapes, name, next);
}

// downcast

instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
  return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next);
}

//===----------------------------------------------------------------------===//
// matmul_inst classes
//===----------------------------------------------------------------------===//

matmul_inst::matmul_inst(value *A, value *B, value *C,
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
                   const std::string &name, instruction *next)
  : builtin_inst(C->get_type(), 3, 0, name, next) {
  : builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) {
  set_operand(0, A);
  set_operand(1, B);
  set_operand(2, C);
}

instruction *matmul_inst::create(value *A, value *B, value *C,
instruction *dot_inst::create_nn(value *A, value *B, value *C,
                                 const std::string &name, instruction *next) {
  return new matmul_inst(A, B, C, name, next);
  return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
}

instruction *dot_inst::create_nt(value *A, value *B, value *C,
                                 const std::string &name, instruction *next) {
  return new dot_inst(A, B, C, NoTrans, Trans, name, next);
}

instruction *dot_inst::create_tn(value *A, value *B, value *C,
                                 const std::string &name, instruction *next) {
  return new dot_inst(A, B, C, Trans, NoTrans, name, next);
}

instruction *dot_inst::create_tt(value *A, value *B, value *C,
                                 const std::string &name, instruction *next) {
  return new dot_inst(A, B, C, Trans, Trans, name, next);
}

//===----------------------------------------------------------------------===//
// trans instructions
//===----------------------------------------------------------------------===//

ir::type* trans_inst::get_res_ty(ir::type* ty) {
  auto shapes = ty->get_tile_shapes();
  std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
  return tile_type::get(ty->get_scalar_ty(), shapes);
}

trans_inst::trans_inst(value *arg, const std::string &name, instruction *next)
  : builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
  set_operand(0, arg);
}

instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) {
  return new trans_inst(arg, name, next);
}
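
A minimal standalone sketch of what get_res_ty does to the shape vector: std::rotate moves the leading dimension to the back, so a 2-D {M, K} tile becomes {K, M} (a transpose), and higher-rank tiles get a cyclic permutation of their axes.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shapes = {64, 8};  // {M, K}
  std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
  std::printf("{%d, %d}\n", shapes[0], shapes[1]);  // prints {8, 64}
}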

//===----------------------------------------------------------------------===//
// select instructions
//===----------------------------------------------------------------------===//

select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
  : builtin_inst(if_value->get_type(), 3, 1, name, next){
  set_operand(0, pred);
  set_operand(1, if_value);
  set_operand(2, else_value);
}

instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
  return new select_inst(pred, if_value, else_value, name, next);
}
//===----------------------------------------------------------------------===//
// builtin instructions
//===----------------------------------------------------------------------===//

// get_global_range
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
                                             const std::string &name, instruction *next)
  : builtin_inst(ty, 0, 1, name, next), axis_(axis) {
@@ -516,6 +573,28 @@ instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::ti
  return new get_global_range_inst(tile_ty, axis, name, next);
}

// get_range_id
get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
  : builtin_inst(ty, 0, 1, name, next), axis_(axis){

}

instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
  return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
}

// atomic cas

atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) {
  set_operand(0, ptr);
  set_operand(1, cmp);
  set_operand(2, val);
}

instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
  return new atomic_cas_inst(ptr, cmp, val, name, next);
}
//===----------------------------------------------------------------------===//
// intrinsic instructions
//===----------------------------------------------------------------------===//
@@ -530,7 +609,7 @@ vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, inst

barrier_inst::barrier_inst(context &ctx, const std::string &name,
                           instruction *next)
  : instruction(type::get_void_ty(ctx), 0, 1, name, next){ }
  : instruction(type::get_void_ty(ctx), 0, 0, name, next){ }

barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
  return new barrier_inst(ctx, name, next);

@@ -128,6 +128,9 @@ ir::value *module::get_value(const std::string& name) {
  return get_value(name, builder_.get_insert_block());
}

const std::string& module::get_name() {
  return name_;
}

void module::seal_block(ir::basic_block *block){
  for(auto &x: incomplete_phis_[block]){
@@ -172,7 +172,7 @@ unsigned tile_type::get_bitwidth() const {
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
  assert(elt_ty && "Can't get a tile of <null> type!");
  assert(shapes.size() && "Can't create a tile with empty shapes!");
  assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
  assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
  // look-up
  context_impl *impl = elt_ty->get_context().p_impl.get();
  tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
38
lib/jit.cpp
@@ -68,7 +68,7 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(


std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
  llvm::Module* result = new llvm::Module("matmul", llvm_context_);
  llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
  passes.selection.run(module, *result);
  // launch information
  auto &launch_info_map = launch_info_map_[result->getName()];
@@ -79,14 +79,14 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
  return std::unique_ptr<llvm::Module>(result);
}

std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &name, const std::string &src) {
  // create AST from Triton-C source
  YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
  yyparse();
  yy_delete_buffer(buffer);
  translation_unit *program = ast_root;
  // create Triton-IR from AST
  ir::module* module = new ir::module("matrix", triton_context_);
  ir::module* module = new ir::module(name, triton_context_);
  program->codegen(module);
  return std::unique_ptr<ir::module>(module);
}
@@ -97,18 +97,20 @@ jit::jit(driver::context *context): driver_context_(context),
}


void jit::autotune(const std::string &src, benchmark_t benchmark) {
void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) {
  // find metaparameters
  auto ptt_module = make_triton_module(src);
  auto ptt_module = make_triton_module(name, src);
  ir::module &tt_module = *ptt_module;
  // set parameters
  passes_wrapper passes(target_.get());
  passes.target_independent(tt_module);
  passes.tune.run(tt_module);
  auto mps = passes.tune.get_params(tt_module);
  // create parameter ranges
  std::vector<std::vector<unsigned>> ranges;
  for(ir::metaparameter *mp: mps)
    ranges.push_back(mp->get_space());
//  std::cout << ranges.size() << std::endl;
  // iterate over parameters
  unsigned i;
  double best = 0;
@@ -117,51 +119,56 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
    i = 0;
    for(ir::metaparameter *mp: mps)
      mp->set_value(params[i++]);
    passes.target_independent(tt_module);
    passes.tune.init(tt_module);
    if(!passes.tune.check_constraints(errors))
      return;
    // Deep copy of the module and tuner
    auto ptt_module = make_triton_module(src);
    auto ptt_module = make_triton_module(name, src);
    ir::module &tt_module = *ptt_module;
    passes_wrapper passes(target_.get());
    passes.target_independent(tt_module);
    passes.tune.run(tt_module);
    i = 0;
    for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
      mp->set_value(params[i++]);
    }
    passes.tune.init(tt_module);
    passes.init(tt_module);
    passes.target_dependent(tt_module);
    driver::device* device = driver_context_->device();
    if(passes.allocation.get_allocated_size() > device->max_shared_memory())
    if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
      return;
    if(passes.tune.get_num_threads() > device->max_threads_per_block())
      return;
    // Compile
    auto ll_module = make_llvm_module(tt_module, passes);
    std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
    std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
    launch_information info = launch_info_map_.at("matmul");
    std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name.c_str()));
    launch_information info = launch_info_map_.at(name.c_str());
    for(unsigned p: params)
      std::cout << p << " " << std::flush;
    // add globals
    for(auto x: tt_module.globals())
      global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
    modules_.push_back(module.get());
    double perf;
    perf = benchmark(kernel.get(), info);
    best = std::max(perf, best);
    std::cout << perf << " [ " << best << " ] " << std::endl;
    modules_.pop_back();
  });
}

void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
  // set parameters
  passes_wrapper passes(target_.get());
  passes.target_independent(tt_module);
  passes.tune.run(tt_module);
  unsigned i = 0;
  for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
    mp->set_value(params[i++]);
  passes.tune.init(tt_module);
  passes.init(tt_module);
  passes.target_dependent(tt_module);
  // check constraints
  std::map<ir::value*, std::vector<std::string>> errors;
  passes.tune.check_constraints(errors);
@@ -184,8 +191,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
  global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
}

void jit::add_module(const std::string &src, const std::vector<unsigned> &params) {
  auto ptt_module = make_triton_module(src);
void jit::add_module(const std::string &name, const std::string &src, const std::vector<unsigned> &params) {
  auto ptt_module = make_triton_module(name, src);
  add_module(*ptt_module, params);
}

@@ -201,4 +208,9 @@ unsigned jit::get_int(const std::string &name){
  return global_ints_.at(name);
}

driver::buffer *jit::get_buffer(const std::string &name){
  driver::cu_module *mod = (driver::cu_module*)modules_.front();
  return mod->symbol(name.c_str());
}

}
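
A hedged usage sketch of the name-aware entry points (hypothetical context, src, benchmark and parameter values; only calls shown in this diff are used). The name passed in must match the kernel function declared in the Triton-C source, since the JIT now keys the LLVM module, the compiled kernel, and its launch information on it:

triton::jit jit(context);
jit.autotune("blocksparse", src, benchmark);        // explores the pruned tuning space
jit.add_module("blocksparse", src, {16, 2, 64, 8}); // compiles one configuration (hypothetical params)
unsigned TM = jit.get_int("TM");                    // reads back a tunable global
driver::buffer *locks = jit.get_buffer("locks");    // looks up a module symbol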