[general] major overhaul of triton-c/triton-ir/triton-jit:

- Added alloc const
- Added atomics
- Pruning tuning space
- Added example for dot/conv/shift
- Bugfixes
This commit is contained in:
Philippe Tillet
2019-04-25 16:17:36 -04:00
parent 0c607c9392
commit 3413aad582
50 changed files with 2051 additions and 570 deletions

View File

@@ -1,6 +1 @@
foreach(PROG matrix)
add_executable(${PROG} ${PROG}.cpp)
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
include_directories(/usr/local/cuda/include/)
target_link_libraries(${PROG} triton)
endforeach(PROG)
add_subdirectory(cpp)

View File

@@ -0,0 +1,6 @@
foreach(PROG dot conv shift)
add_executable(${PROG} ${PROG}.cpp)
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
include_directories(/usr/local/cuda/include/)
target_link_libraries(${PROG} triton)
endforeach(PROG)

View File

@@ -1,17 +1,18 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
const char* src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {8};
const tunable int32 TK = {8};
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K, int32 bound){
void blocksparse(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K, int32 bound){
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
@@ -22,9 +23,9 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0;){
C = dot(a, b, C);
C = dot(a, trans(b), C);
pa = pa + TK*M;
pb = pb + TK*K;
pb = pb + TK*N;
k = k - TK;
int1 checka[TM, TK] = k > bound;
int1 checkb[TN, TK] = k > bound;
@@ -51,71 +52,24 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
}
)";
template<class T>
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
for(size_t k = 0; k < K; k++)
acc += a[m + k*M] * b[n + k*N];
c[m + n*M] = acc;
std::vector<int> make_deltas(std::vector<int> mask, int K, int N){
std::vector<std::vector<std::pair<int,int>>> pairs(N);
unsigned int current = 0;
for(int k = 0; k < K; k++)
for(int n = 0; n < N; n++){
if(mask[k + n*K])
pairs[n].push_back({current, k});
}
}
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false)
{ if (run) start(); }
void start()
{ _start = high_resolution_clock::now(); }
nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
private:
high_resolution_clock::time_point _start;
};
template<class T>
T min(std::vector<T> x)
{ return *std::min_element(x.begin(), x.end()); }
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(&device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 512, N = 512, K = 512;
int32_t M = 512, N = 32, K = 2048;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -183,14 +137,13 @@ int main() {
8, 8,
4
};
jit.autotune(src, benchmark);
jit.add_module(src, params);
jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm(rc, ha, hb, M, N, K);
simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

286
examples/cpp/common.hpp Normal file
View File

@@ -0,0 +1,286 @@
#include <vector>
#include <chrono>
#include "triton/driver/device.h"
#include <algorithm>
template<class T, bool AT, bool BT>
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
for(size_t k = 0; k < K; k++)
acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
c[m + n*M] = acc;
}
}
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false)
{ if (run) start(); }
void start()
{ _start = high_resolution_clock::now(); }
nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
private:
high_resolution_clock::time_point _start;
};
template<class T>
T min(std::vector<T> x)
{ return *std::min_element(x.begin(), x.end()); }
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(&device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
//
void build_conv_lut(int TK,
int stride_d, int stride_h, int stride_w, int stride_c,
int pad_d, int pad_h, int pad_w,
int T, int R, int S,
std::vector<int>& res, std::vector<int>& masks) {
/* convolution parameters */
int F = T * R * S;
int Nlut = (TK + F - 1) / F * F;
int upsample_w = 1;
int upsample_h = 1;
int upsample_d = 1;
/* unpack index wrt filters */
auto unpack = [&](int32_t trs){
int32_t tr = trs / S;
int32_t s = trs - tr*S;
int32_t t = tr / R;
int32_t r = tr - t*R;
return std::make_tuple(t, r, s);
};
/* increments */
for(size_t i = 0; i < Nlut; ++i)
res[i] = (((i + TK) % Nlut) - i);
/* deltas */
size_t Ds0 = Nlut;
size_t Ds1 = upsample_w;
size_t Ds2 = upsample_h;
size_t Ds3 = upsample_d;
for(size_t pd = 0; pd < Ds3; ++pd)
for(size_t ph = 0; ph < Ds2; ++ph)
for(size_t pw = 0; pw < Ds1; ++pw){
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
// cumulative increments
for(size_t i = 0; i < Ds0; ++i){
int32_t ctrs = i;
int32_t c = ctrs / F;
int32_t t, r, s;
std::tie(t, r, s) = unpack(ctrs % F);
// next indices
int32_t nextctrs = ctrs + TK;
int32_t nextc = nextctrs / F;
int32_t nextt, nextr, nexts;
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
// diffs
int32_t cdiff = nextc - c;
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
// delta pointers
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
}
}
/* Masks */
size_t Ms0 = Nlut;
size_t Ms1 = 2*pad_w + 1;
size_t Ms2 = 2*pad_h + 1;
size_t Ms3 = 2*pad_d + 1;
for(size_t pd = 0; pd < Ms3; ++pd)
for(size_t ph = 0; ph < Ms2; ++ph)
for(size_t pw = 0; pw < Ms1; ++pw){
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
for(size_t i = 0; i < Ms0; ++i){
int32_t t, r, s;
int32_t mask = 0x0;
for(size_t j = 0; j < TK; ++j){
std::tie(t, r, s) = unpack((i + j) % F);
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
}
masks_ptr[i] = mask;
}
}
for(size_t i = 0; i < Nlut; ++i)
masks[i] = 0x0;
}
// Index computation
inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }
// Pack
template <class T> T clamp(T x, T lo, T hi){
return std::max<T>(lo, std::min<T>(x, hi));
}
template<class T, class U>
T pack(U* tmp, U scale);
template<>
double pack<double, double>(double* tmp, double scale)
{ return tmp[0]*scale; }
template<>
float pack<float, float>(float* tmp, float scale)
{ return tmp[0]*scale; }
template<>
int pack<int, float>(float* tmp, float scale)
{
int res = 0;
for(int i = 0; i < 4; i++){
int8_t clamped = std::round(clamp(tmp[i]*scale, (float)-128, (float)127));
res |= (clamped & 0xFF) << (8*i);
}
return res;
}
template<class T> struct pack_increment
{ enum{ VALUE = 1}; };
template<> struct pack_increment<int>
{ enum{ VALUE = 4}; };
// Dot
template<class T>
inline T dot(T x, T y, T z)
{
return std::fma(x, y, z);
}
inline int dot(int x, int y, int z){
int res = 0;
for(int i = 0; i < 4; i++){
int32_t a = ((x >> (8*i)) & 0x000000FF);
int32_t b = ((y >> (8*i)) & 0x000000FF);
res += (*(int8_t*)(&a)) * (*(int8_t*)(&b));
}
return res + z;
}
template<class IN_DTYPE, class OUT_DTYPE>
void cpp_conv_nchw(int32_t C, int32_t N, int32_t K,
int32_t D, int32_t H, int32_t W,
int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t M, int32_t P, int32_t Q,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F)
{
static const int PACK_IN = pack_increment<IN_DTYPE>::VALUE;
static const int PACK_OUT = pack_increment<OUT_DTYPE>::VALUE;
if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4");
if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4");
C /= PACK_IN;
K /= PACK_OUT;
int32_t Kout = K;
IN_DTYPE accs[PACK_OUT];
float tmp[PACK_OUT];
for(int32_t m = 0 ; m < M; ++m)
for(int32_t p = 0 ; p < P; ++p)
for(int32_t q = 0; q < Q; ++q)
for(int32_t n = 0; n < N; ++n)
for(int32_t k = 0; k < Kout ; ++k)
{
for(int32_t i = 0; i < PACK_OUT; ++i)
accs[i] = 0;
int32_t mm = m*stride_d - pad_d;
int32_t pp = p*stride_h - pad_h;
int32_t qq = q*stride_w - pad_w;
for(int32_t kk = 0; kk < PACK_OUT; ++kk)
for(int32_t c = 0; c < C; ++c)
for(int32_t t = 0; t < T; ++t)
for(int32_t r = 0; r < R; ++r)
for(int32_t s = 0; s < S; ++s){
int32_t d = mm + t;
int32_t h = pp + r;
int32_t w = qq + s;
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W);
IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0;
IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)];
accs[kk] = dot(i, f, accs[kk]);
}
for(int32_t kk = 0; kk < PACK_OUT; ++kk){
tmp[kk] = accs[kk];
}
O[idx(n, k, m, p, q, N, K, M, P, Q)] = tmp[0];
}
}
// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
int32_t K,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F,
const std::vector<int32_t> shift_h,
const std::vector<int32_t> shift_w)
{
OUT_DTYPE acc;
for(int32_t p = 0; p < H; ++p)
for(int32_t q = 0; q < W; ++q)
for(int32_t bs = 0; bs < BS; ++bs)
for(int32_t k = 0; k < K; ++k)
{
acc = 0;
for(int32_t c = 0; c < C; ++c){
int32_t h = p + shift_h[c];
int32_t w = q + shift_w[c];
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
IN_DTYPE b = F[k + c*K];
acc = dot(a, b, acc);
}
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
}
}

236
examples/cpp/conv.cpp Normal file
View File

@@ -0,0 +1,236 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
std::string src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};
__constant__ int32* delta = alloc_const int32[18];
__constant__ int32* masks = alloc_const int32[1024];
void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
int32 M, int32 N, int32 K,
int32 AN, int32 AH, int32 AW,
int32 CN, int32 CK, int32 CP, int32 CQ,
int32 AC, int32 AR, int32 AS,
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w,
int32 bound){
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rb1[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 ranh[TM] = rxa / CQ;
int32 raw[TM] = rxa % CQ - pad_w;
int32 ran[TM] = ranh / CP;
int32 rah[TM] = ranh % CP - pad_h;
int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
int32 racr[TK] = rka / AS;
int32 ras[TK] = rka % AS;
int32 rac[TK] = racr / AR;
int32 rar[TK] = racr % AR;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
__constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + AR*AS + rka;
int32 d[TK] = *pd;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
__constant__ int32* pincm[TM] = delta;
int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm;
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
pb = pb + TK*CK;
pa = pa + d[newaxis, :];
b = *pb;
pd = pd + incd;
pincd = pincd + incd;
d = *pd;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CP*CQ);
int32 rcpq[TM] = rxc % (CP*CQ);
int32 rc0[TM] = rcn * ldc_n + rcpq;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = rc1 < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
})";
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t AN = 4, CK = 32;
int32_t AD = 1, AH = 24, AW = 240;
int32_t BC = 64, BT = 1, BR = 3, BS = 3;
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
int32_t CM = (AD*upsample_d - BT + 1 + 2*pad_d + stride_d - 1)/stride_d;
int32_t CP = (AH*upsample_h - BR + 1 + 2*pad_h + stride_h - 1)/stride_h;
int32_t CQ = (AW*upsample_w - BS + 1 + 2*pad_w + stride_w - 1)/stride_w;
// equivalent matmul dimensions
int32_t M = AN*CM*CP*CQ;
int32_t N = CK;
int32_t K = BC*BT*BR*BS;
std::vector<float> hc(AN*CP*CQ*CK);
std::vector<float> rc(AN*CP*CQ*CK);
std::vector<float> ha(AN*BC*AH*AW);
std::vector<float> hb(BC*BR*BS*CK);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = 1;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = 1;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// memory strides for data
int32_t stride_i_w = 1;
int32_t stride_i_h = AW*stride_i_w;
int32_t stride_i_d = AH*stride_i_h;
int32_t stride_i_c = AD*stride_i_d;
int32_t stride_i_n = BC*stride_i_c;
// memory strides for filters
int32_t stride_f_k = 1;
int32_t stride_f_s = CK*stride_f_k;
int32_t stride_f_r = BS*stride_f_s;
int32_t stride_f_t = BR*stride_f_r;
int32_t stride_f_c = BT*stride_f_t;
// memory stride for activations
int32_t stride_o_q = 1;
int32_t stride_o_p = CQ*stride_o_q;
int32_t stride_o_m = CP*stride_o_p;
int32_t stride_o_k = CM*stride_o_m;
int32_t stride_o_n = CK*stride_o_k;
// look-up table
int TK = 8;
int F = BT * BR * BS;
int nlut = (TK + F - 1) / F * F;
std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, BT, BR, BS, h_delta, h_masks);
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned TK = jit.get_int("TK");
// initialize constant memory
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
stream->synchronize();
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set arguments
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, AN);
kernel->setArg(7, AH);
kernel->setArg(8, AW);
kernel->setArg(9, AN);
kernel->setArg(10, CK);
kernel->setArg(11, CP);
kernel->setArg(12, CQ);
kernel->setArg(13, BC);
kernel->setArg(14, BR);
kernel->setArg(15, BS);
kernel->setArg(16, stride_i_n);
kernel->setArg(17, stride_i_c);
kernel->setArg(18, stride_i_h);
kernel->setArg(19, stride_i_w);
kernel->setArg(20, stride_o_n);
kernel->setArg(21, stride_o_k);
kernel->setArg(22, stride_o_p);
kernel->setArg(23, stride_o_q);
kernel->setArg(24, pad_h);
kernel->setArg(25, pad_w);
kernel->setArg(26, bound);
// dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};
// run
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4
};
// jit.autotune("conv", src, benchmark);
jit.add_module("conv", src, params);
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
cpp_conv_nchw(BC, AN, CK, AD, AH, AW, BT, BR, BS, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, CM, CP, CQ, rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

162
examples/cpp/dot.cpp Normal file
View File

@@ -0,0 +1,162 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
for(int32 L = __atomic_cas(plock, 0, 1); L == 1; L = __atomic_cas(plock, 0, 1)){}
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + (checkc ? *pc : 0);
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
}
)";
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
std::vector<int32_t> hlocks(2048);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// init locks
stream->write(dlocks, true, 0, hlocks);
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, dlocks);
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
// dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};
// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

212
examples/cpp/shift.cpp Normal file
View File

@@ -0,0 +1,212 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
// K = channels
// M = batch * height * width
// N = number of feature maps
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
__constant__ int32* delta = alloc_const int32[256];
__constant__ int32* masks = alloc_const int32[8192];
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K,
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
__constant__ int32* pd[TK] = delta + rka;
int32 pad_h = AR/2;
int32 pad_w = AS/2;
int32 rawhc[TM] = rxa / ABS;
int32 raw[TM] = rawhc % AW - pad_w;
int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH - pad_h;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
for(int32 k = K; k > 0; k = k - TK){
int32 delta[TK] = *pd;
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
int1 m[TM, TK] = *pm > 0;
fp32 a[TM, TK] = m ? *pa : 0;
fp32 b[TN, TK] = *pb;
C = dot(a, trans(b), C);
pb = pb + TK*N;
pd = pd + TK;
pm = pm + TK;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
}
)";
std::vector<int32_t> shift_deltas(// strides
int32_t stride_w, int32_t stride_h, int32_t stride_c,
// shift
int32_t C,
const std::vector<int32_t>& shift_h,
const std::vector<int32_t>& shift_w) {
std::vector<int32_t> res(C);
for(unsigned c = 0; c < C; c++){
res[c] = c*stride_c;
res[c] += shift_h[c]*stride_h;
res[c] += shift_w[c]*stride_w;
}
return res;
}
std::vector<int32_t> shift_masks(int32_t C,
const std::vector<int32_t>& shift_h,
const std::vector<int32_t>& shift_w,
int32_t R, int32_t S) {
size_t S0 = C;
size_t S1 = R;
size_t S2 = S;
std::vector<int32_t> res(S0*S1*S2);
for(size_t ph = 0; ph < S1; ++ph)
for(size_t pw = 0; pw < S2; ++pw){
int32_t* ptr = &res[ph*S0 + pw*S0*S1];
for(size_t i = 0; i < S0; ++i){
bool in_bounds_h = shift_h[i] + ph >= 0 && shift_h[i] + ph < R;
bool in_bounds_w = shift_w[i] + pw >= 0 && shift_w[i] + pw < S;
ptr[i] = in_bounds_h && in_bounds_w;
}
}
return res;
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t R = 3, S = 3;
int32_t BS = 4, F = 128;
int32_t H = 32, W = 32;
int32_t C = 128;
// equivalent matmul dimensions
int32_t M = BS*H*W;
int32_t N = F;
int32_t K = C;
std::cout << M << " " << N << " " << K << std::endl;
std::vector<float> hc(BS*H*W*F);
std::vector<float> rc(BS*H*W*F);
std::vector<float> ha(BS*C*H*W);
std::vector<float> hb(F*C);
// strides
int32_t stride_i_bs = 1;
int32_t stride_i_w = BS*stride_i_bs;
int32_t stride_i_h = W*stride_i_w;
int32_t stride_i_c = H*stride_i_h;
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = rand() % R - R/2;
shift_w[c] = rand() % S - S/2;
}
// initialize buffers
srand(0);
for(int c = 0 ; c < C; c++)
for(int h = 0 ; h < H; h++)
for(int w = 0 ; w < W; w++)
for(int bs = 0 ; bs < BS; bs++){
float value = (float)rand() / RAND_MAX;
size_t idx = bs + w*stride_i_w + h*stride_i_h + c*stride_i_c;
ha[idx] = value;
}
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand() / RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
std::vector<int32_t> h_delta = shift_deltas(stride_i_w, stride_i_h, stride_i_c, C, shift_h, shift_w);
std::vector<int32_t> h_masks = shift_masks(C, shift_h, shift_w, R, S);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// initialize constant memory
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
stream->synchronize();
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);
kernel->setArg(2, dc);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, BS);
kernel->setArg(7, H);
kernel->setArg(8, W);
kernel->setArg(9, R);
kernel->setArg(10, S);
// dry run
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};
// shift
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4
};
// jit.autotune("shift", src, benchmark);
jit.add_module("shift", src, params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
shift_conv(C, H, W, BS, F, rc, ha, hb, shift_h, shift_w);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

93
examples/cpp/shift.ptx Normal file
View File

@@ -0,0 +1,93 @@
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//
.version 6.3
.target sm_60
.address_size 64
// .globl _Z25shift_cuda_forward_kernelPKfPKiPfiiii
.visible .entry shift(
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6
)
{
.reg .pred %p<10>;
.reg .f32 %f<2>;
.reg .b32 %r<31>;
.reg .b64 %rd<13>;
ld.param.u64 %rd1, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0];
ld.param.u64 %rd3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1];
ld.param.u64 %rd2, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2];
ld.param.u32 %r3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3];
ld.param.u32 %r4, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4];
ld.param.u32 %r5, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5];
ld.param.u32 %r6, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6];
cvta.to.global.u64 %rd4, %rd3;
mov.u32 %r7, %ntid.x;
mov.u32 %r8, %ctaid.x;
mov.u32 %r9, %tid.x;
mad.lo.s32 %r1, %r7, %r8, %r9;
mul.lo.s32 %r10, %r4, %r3;
mul.lo.s32 %r11, %r10, %r5;
mul.lo.s32 %r12, %r11, %r6;
mul.lo.s32 %r13, %r5, %r4;
mul.lo.s32 %r14, %r13, %r6;
rem.s32 %r15, %r1, %r14;
sub.s32 %r16, %r1, %r15;
mul.lo.s32 %r17, %r6, %r5;
div.s32 %r18, %r15, %r17;
mul.lo.s32 %r19, %r18, %r17;
sub.s32 %r20, %r15, %r19;
div.s32 %r21, %r20, %r5;
mul.lo.s32 %r22, %r21, %r6;
sub.s32 %r23, %r20, %r22;
shl.b32 %r24, %r18, 1;
mul.wide.s32 %rd5, %r24, 4;
add.s64 %rd6, %rd4, %rd5;
ld.global.nc.u32 %r25, [%rd6];
add.s32 %r26, %r25, %r21;
ld.global.nc.u32 %r27, [%rd6+4];
add.s32 %r28, %r23, %r27;
add.s32 %r29, %r16, %r19;
mad.lo.s32 %r30, %r26, %r5, %r29;
add.s32 %r2, %r30, %r28;
setp.lt.s32 %p1, %r1, %r12;
setp.gt.s32 %p2, %r26, -1;
and.pred %p3, %p1, %p2;
setp.lt.s32 %p4, %r26, %r5;
and.pred %p5, %p3, %p4;
setp.gt.s32 %p6, %r28, -1;
and.pred %p7, %p5, %p6;
setp.lt.s32 %p8, %r28, %r6;
and.pred %p9, %p7, %p8;
@!%p9 bra BB0_2;
bra.uni BB0_1;
BB0_1:
cvta.to.global.u64 %rd7, %rd1;
mul.wide.s32 %rd8, %r1, 4;
add.s64 %rd9, %rd7, %rd8;
ld.global.nc.f32 %f1, [%rd9];
cvta.to.global.u64 %rd10, %rd2;
mul.wide.s32 %rd11, %r2, 4;
add.s64 %rd12, %rd10, %rd11;
st.global.f32 [%rd12], %f1;
BB0_2:
ret;
}

View File

@@ -74,8 +74,8 @@ class constant;
class node {
protected:
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src);
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty);
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
public:
@@ -164,6 +164,27 @@ private:
const constant* axis_;
};
class get_range_id: public builtin_expression{
public:
get_range_id(node *axis): axis_((constant*)axis) { }
ir::value* codegen(ir::module *) const;
private:
const constant* axis_;
};
class atomic_cas: public builtin_expression{
public:
atomic_cas(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
ir::value* codegen(ir::module *) const;
private:
const node *ptr_;
const node *cmp_;
const node *val_;
};
class matmul_expression: public builtin_expression{
public:
matmul_expression(node* A, node *B, node *C):
@@ -176,6 +197,49 @@ private:
const expression *C_;
};
class max_expression: public builtin_expression{
public:
max_expression(node* x, node* y)
: x_((expression*)x), y_((expression*)y){ }
ir::value* codegen(ir::module *) const;
private:
const expression *x_;
const expression *y_;
};
class min_expression: public builtin_expression{
public:
min_expression(node* x, node* y)
: x_((expression*)x), y_((expression*)y){ }
ir::value* codegen(ir::module *mod) const;
private:
const expression *x_;
const expression *y_;
};
class select_expression: public builtin_expression{
public:
select_expression(node* pred, node* if_value, node* else_value)
: pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { }
ir::value* codegen(ir::module *mod) const;
private:
const expression *pred_;
const expression *if_value_;
const expression *else_value_;
};
class trans_expression: public builtin_expression{
public:
trans_expression(node *arg): arg_(arg) {}
ir::value* codegen(ir::module *mod) const;
private:
node* arg_;
};
class indexing_expression: public postfix_expression{
public:
@@ -189,6 +253,8 @@ private:
const list<slice*>* slices_;
};
class named_expression: public expression {
public:
named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; }

View File

@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR CONTINUE
%token NEWAXIS ELLIPSIS AT
%token GET_GLOBAL_RANGE DOT ALLOC_CONST
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ALLOC_CONST
%start translation_unit
%%
@@ -118,8 +118,15 @@ identifier
builtin
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id($3); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); }
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas($3, $5, $7); }
;
primary_expression
: identifier { $$ = new named_expression($1); }

View File

@@ -41,7 +41,13 @@ using triton::ast::return_void;
"fp64" { return return_impl(FP64, yytext); }
"..." { return return_impl(ELLIPSIS, yytext); }
"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); }
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
"dot" { return return_impl(DOT, yytext); }
"max" { return return_impl(MAX, yytext); }
"min" { return return_impl(MIN, yytext); }
"select" { return return_impl(SELECT, yytext); }
"trans" { return return_impl(TRANS, yytext); }
"continue" { return return_impl(CONTINUE, yytext); }
"alloc_const" { return return_impl(ALLOC_CONST, yytext); }
{L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); }
@@ -52,8 +58,6 @@ using triton::ast::return_void;
L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); }
{D}+{E}{FS}? { return return_impl(CONSTANT, yytext); }
{D}*"."{D}+({E})?{FS}? { return return_impl(CONSTANT, yytext); }
{D}+"."{D}*({E})?{FS}? { return return_impl(CONSTANT, yytext); }
L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); }

View File

@@ -1,45 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LAYOUT_H
#define TDL_INCLUDE_IR_CODEGEN_LAYOUT_H
#include <vector>
#include <map>
namespace triton {
namespace ir {
class module;
class instruction;
class value;
}
namespace codegen{
struct shared_view_info{
ir::value *usr;
bool has_dedicated_storage;
};
class layout {
private:
typedef std::vector<shared_view_info> shared_view_val_t;
void add_phi_nodes(ir::value *v);
void add_shared_views(ir::value *v);
public:
// accessors
unsigned get_num_shared_views(ir::value *v);
shared_view_info get_shared_view(ir::value *v, unsigned idx);
// run
void run(ir::module &mod);
private:
std::map<ir::value*, shared_view_val_t> shared_views_;
};
}
}
#endif

View File

@@ -0,0 +1,27 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#include <tuple>
#include <vector>
#include <set>
namespace triton {
namespace ir {
class module;
}
namespace codegen{
class tune;
class optimize_cse {
public:
optimize_cse() {}
void run(ir::module &mod);
};
}
}
#endif

View File

@@ -0,0 +1,31 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
#include <tuple>
#include <vector>
#include <set>
namespace triton {
namespace ir {
class module;
}
namespace codegen{
class tune;
class optimize_dot {
public:
optimize_dot(tune* params): params_(params) {}
void run(ir::module &mod);
private:
tune* params_;
};
}
}
#endif

View File

@@ -0,0 +1,33 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include <tuple>
#include <vector>
#include <set>
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
}
namespace codegen{
class optimize_trans {
private:
ir::value *replace_phi(ir::value* value, std::vector<ir::instruction*>& to_delete, ir::builder &builder);
public:
optimize_trans() {}
void run(ir::module &mod);
};
}
}
#endif

View File

@@ -7,7 +7,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_info.h"
namespace llvm{
@@ -21,9 +21,9 @@ namespace llvm{
namespace triton{
namespace codegen{
class allocation;
class shmem_allocation;
class tune;
class buffer_info_pass;
class shmem_info;
class target;
typedef std::vector<llvm::Value*> indices_t;
@@ -129,7 +129,7 @@ private:
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
public:
selection(allocation *alloc, tune *params, buffer_info_pass *buffer_info, target *tgt)
selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ }
void run(ir::module &src, llvm::Module &dst);
@@ -139,11 +139,12 @@ private:
tmap_t tmap_;
pmap_t pmap_;
pmap_t last_block_;
allocation *alloc_;
shmem_allocation *alloc_;
tune *params_;
target *tgt_;
buffer_info_pass *buffer_info_;
shmem_info *buffer_info_;
std::map<ir::metaparameter*, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
};
}

View File

@@ -1,41 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_SHARED_COPY_H
#define TDL_INCLUDE_CODEGEN_SHARED_COPY_H
#include <tuple>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class builder;
class basic_block;
}
namespace codegen{
class buffer_info_pass;
class place_shared_copy {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
private:
bool intersect(const interval_vec_t &I, interval_t i);
void add_copy(ir::value *x, ir::builder &builder);
public:
place_shared_copy(buffer_info_pass *info): info_(info) { }
void run(ir::module &mod);
private:
buffer_info_pass *info_;
};
}
}
#endif

View File

@@ -16,12 +16,12 @@ namespace codegen{
class layout;
class target_tuner;
class liveness;
class buffer_info_pass;
class shmem_liveness;
class shmem_info;
class allocation {
class shmem_allocation {
public:
allocation(liveness *live, buffer_info_pass *buffer_info)
shmem_allocation(shmem_liveness *live, shmem_info *buffer_info)
: liveness_(live), buffer_info_(buffer_info){ }
// utilities
@@ -39,8 +39,8 @@ private:
std::map<ir::value*, unsigned> num_bytes_;
size_t allocated_size_;
// dependences
liveness *liveness_;
buffer_info_pass *buffer_info_;
shmem_liveness *liveness_;
shmem_info *buffer_info_;
};
}

View File

@@ -17,10 +17,10 @@ namespace ir {
namespace codegen{
class allocation;
class buffer_info_pass;
class shmem_allocation;
class shmem_info;
class barriers {
class shmem_barriers {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
@@ -36,12 +36,12 @@ private:
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
public:
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
shmem_barriers(shmem_allocation *alloc, shmem_info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
void run(ir::module &mod);
private:
allocation *alloc_;
buffer_info_pass *buffer_info_;
shmem_allocation *alloc_;
shmem_info *buffer_info_;
};

View File

@@ -10,18 +10,19 @@ namespace ir {
class module;
class value;
class phi_node;
class instruction;
}
namespace codegen{
class buffer_info_pass {
class shmem_info {
public:
void run(ir::module &mod);
// queries
bool is_double(ir::value *x);
void add_shared(ir::value *v);
bool is_shared(ir::value *x);
bool is_loop_latch(ir::phi_node *phi, ir::value *terminator);
bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
ir::value *get_reference(ir::value *x);
void replace(ir::value* before, ir::value *after);

View File

@@ -15,7 +15,7 @@ namespace codegen{
typedef unsigned slot_index;
class buffer_info_pass;
class shmem_info;
struct segment {
slot_index start;
@@ -30,7 +30,7 @@ struct segment {
}
};
class liveness {
class shmem_liveness {
private:
typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<ir::value*, segment> intervals_map_t;
@@ -43,7 +43,7 @@ public:
public:
// constructor
liveness(buffer_info_pass *info): info_(info){ }
shmem_liveness(shmem_info *info): info_(info){ }
// accessors
const intervals_map_t& intervals() const { return intervals_; }
@@ -53,7 +53,7 @@ public:
void run(ir::module &mod);
private:
buffer_info_pass *info_;
shmem_info *info_;
has_storage_map_t has_dedicated_storage_;
indices_map_t indices_;
intervals_map_t intervals_;

View File

@@ -24,6 +24,7 @@ public:
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
bool is_gpu() const;
private:
@@ -37,6 +38,7 @@ public:
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
class nvidia_cu_target: public target {
@@ -46,6 +48,7 @@ public:
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
class cpu_target: public target {
@@ -55,6 +58,7 @@ public:
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
}

View File

@@ -90,7 +90,7 @@ class cu_module: public module {
public:
cu_module(driver::context* context, llvm::Module *module);
cu_module(driver::context* context, const std::string& source);
cu_buffer symbol(const char * name) const;
cu_buffer* symbol(const char * name) const;
private:
std::string source_;

View File

@@ -67,6 +67,7 @@ public:
value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
value *create_downcast(value *arg, const std::string &name = "");
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
// Binary instructions
@@ -124,7 +125,11 @@ public:
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
value *create_get_range_id(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_vectorize(value *arg, const std::string &name = "");

View File

@@ -54,6 +54,7 @@ public:
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
bool has_value() { return has_value_; }
const std::vector<unsigned>& get_space() { return space_; }
void set_space(const std::vector<unsigned> &space) { space_ = space; }
private:
std::vector<unsigned> space_;

View File

@@ -464,6 +464,17 @@ public:
};
// downcast
class downcast_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "downcast"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// builtin_inst classes
//===----------------------------------------------------------------------===//
@@ -488,17 +499,76 @@ private:
unsigned axis_;
};
class matmul_inst: public builtin_inst {
class get_range_id_inst: public builtin_inst {
private:
matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(value *A, value *B, value *C,
const std::string &name = "",
instruction *next = nullptr);
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
private:
unsigned axis_;
};
class atomic_cas_inst: public builtin_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); }
public:
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
bool is_a_trans() { return AT_ == Trans; }
bool is_b_trans() { return BT_ == Trans; }
private:
TransT AT_;
TransT BT_;
};
//class outer_inst: public builtin_inst {
//private:
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
//public:
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
//};
class trans_inst: public builtin_inst {
public:
ir::type* get_res_ty(ir::type* in);
private:
trans_inst(value *arg, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; }
public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// intrinsics classes

View File

@@ -66,6 +66,7 @@ public:
// Getters
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);

View File

@@ -10,13 +10,15 @@
#include "triton/driver/kernel.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/optimize_cse.h"
#include "triton/codegen/optimize_trans.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/target.h"
#include "triton/codegen/vectorize.h"
#include <functional>
namespace llvm {
@@ -45,48 +47,59 @@ public:
struct passes_wrapper {
passes_wrapper(codegen::target* target)
: shared(&buffer_info), liveness(&buffer_info),
allocation(&liveness, &buffer_info),
barriers(&allocation, &buffer_info),
: shmem_liveness(&shmem_info),
shmem_allocation(&shmem_liveness, &shmem_info),
shmem_barriers(&shmem_allocation, &shmem_info),
vectorize(&tune),
selection(&allocation, &tune, &buffer_info, target),
selection(&shmem_allocation, &tune, &shmem_info, target),
optimize_dot(&tune),
optimize_cse(),
optimize_trans(),
target_(target) { }
void init(ir::module &module) {
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
// ir::print(module, std::cout);
}
void target_dependent(ir::module &module) {
if(target_->is_gpu()){
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
shmem_barriers.run(module);
}
vectorize.run(module);
}
codegen::tune tune;
codegen::buffer_info_pass buffer_info;
codegen::place_shared_copy shared;
codegen::liveness liveness;
codegen::allocation allocation;
codegen::barriers barriers;
codegen::shmem_info shmem_info;
codegen::shmem_liveness shmem_liveness;
codegen::shmem_allocation shmem_allocation;
codegen::shmem_barriers shmem_barriers;
codegen::vectorize vectorize;
codegen::selection selection;
codegen::optimize_dot optimize_dot;
codegen::optimize_cse optimize_cse;
codegen::optimize_trans optimize_trans;
codegen::target* target_;
};
private:
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
std::unique_ptr<ir::module> make_triton_module(const std::string &src);
std::unique_ptr<ir::module> make_triton_module(const std::string &name, const std::string &src);
public:
jit(driver::context* context);
void autotune(const std::string &src, benchmark_t benchmark);
void autotune(const std::string &name, const std::string &src, benchmark_t benchmark);
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
void add_module(const std::string &name, const std::string &src, const std::vector<unsigned>& params = {});
driver::kernel* get_function(const std::string &name);
launch_information get_launch_info(const std::string &name);
unsigned get_int(const std::string &name);
driver::buffer *get_buffer(const std::string &name);
private:
std::vector<driver::module*> modules_;

View File

@@ -95,55 +95,75 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
throw std::runtime_error("unreachable");
}
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
ir::value *tmp = ir::undef_value::get(ty);
implicit_broadcast(mod, arg, tmp);
}
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
ir::builder &builder = mod->get_builder();
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
// Both are scalar
ir::type *res_ty = nullptr;
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
return;
// One argument is scalar
if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
scalar = builder.create_splat(scalar, shapes);
else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
res_ty = lhs_ty;
else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
res_ty = rhs_ty;
else{
auto lhs_shapes = lhs_ty->get_tile_shapes();
auto rhs_shapes = rhs_ty->get_tile_shapes();
size_t lhs_size = lhs_shapes.size();
size_t rhs_size = rhs_shapes.size();
size_t res_size = std::max(lhs_size, rhs_size);
ir::type::tile_shapes_t res_shapes(res_size);
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
for(int i = 0; i < res_size; i++){
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
else if(i >= res_size - lhs_size)
res_shapes[i] = lhs_shapes[i];
else if(i >= res_size - rhs_size)
res_shapes[i] = rhs_shapes[i];
}
res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
}
implicit_broadcast(mod, res_ty, rhs);
implicit_broadcast(mod, res_ty, lhs);
}
void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
ir::builder &builder = mod->get_builder();
ir::type *src_ty = src->get_type();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
// Both are scalar
if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
return;
// Broadcast scalar
if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
src = builder.create_splat(src, ty->get_tile_shapes());
return;
}
// Downcast tile
if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
for(ir::constant *shape: src_ty->get_tile_shapes())
if(shape != one)
throw std::runtime_error("cannot downcast");
src = builder.create_downcast(src);
return;
}
// Both are arrays
auto lhs_shapes = lhs->get_type()->get_tile_shapes();
auto rhs_shapes = rhs->get_type()->get_tile_shapes();
if(lhs_shapes == rhs_shapes)
return;
int lhs_dim = lhs_shapes.size();
int rhs_dim = rhs_shapes.size();
auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
size_t ndim = longest.size();
int off = longest.size() - shortest.size();
for(int i = longest.size() - 1; i>= 0; i--){
if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
throw std::runtime_error("cannot broadcast");
}
auto dst_shapes = ty->get_tile_shapes();
auto src_shapes = src_ty->get_tile_shapes();
int dst_dim = dst_shapes.size();
int src_dim = src_shapes.size();
// Pad
int off = dst_dim - src_dim;
for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), one);
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
src_shapes.insert(src_shapes.begin(), one);
if(off > 0)
target = builder.create_reshape(target, shortest);
src = builder.create_reshape(src, src_shapes);
// Broadcast
ir::type::tile_shapes_t shapes(ndim);
for(size_t i = 0; i < ndim; i++)
shapes[i] = shortest[i]==one?longest[i]:shortest[i];
if(shapes != lhs_shapes)
lhs = builder.create_broadcast(lhs, shapes);
if(shapes != rhs_shapes)
rhs = builder.create_broadcast(rhs, shapes);
for(int i = dst_dim - 1; i>= 0; i--)
if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
throw std::runtime_error("cannot broadcast");
if(dst_shapes != src_shapes)
src = builder.create_broadcast(src, dst_shapes);
}
/* Helper */
@@ -336,7 +356,9 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
return builder.create_cond_br(cond, loop_bb, next_bb);
});
init_->codegen(mod);
builder.create_br(loop_bb);
ir::value *cond = stop_->codegen(mod);
builder.create_cond_br(cond, loop_bb, next_bb);
// builder.create_br(loop_bb);
builder.set_insert_point(loop_bb);
if(!is_terminator(statements_->codegen(mod)))
mod->get_continue_fn()();
@@ -378,6 +400,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
builder.create_br(endif_bb);
}
// Endif
mod->seal_block(endif_bb);
builder.set_insert_point(endif_bb);
return nullptr;
}
@@ -422,7 +445,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
else if(expr_){
value = expr_->codegen(mod);
value = explicit_cast(mod->get_builder(), value, ty);
implicit_broadcast(mod, value, ty);
implicit_broadcast(mod, ty, value);
}
value->set_name(name);
mod->set_value(name, value);
@@ -543,6 +566,19 @@ ir::value* get_global_range::codegen(ir::module *mod) const {
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
}
// get_range_id
ir::value* get_range_id::codegen(ir::module *mod) const {
return mod->get_builder().create_get_range_id(axis_->value());
}
// atomic cas
ir::value* atomic_cas::codegen(ir::module *mod) const {
ir::value *ptr = ptr_->codegen(mod);
ir::value *cmp = cmp_->codegen(mod);
ir::value *val = val_->codegen(mod);
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
}
// matmul
ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod);
@@ -554,10 +590,37 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
// ir::value *tmp = ir::undef_value::get(tile_ty);
// implicit_broadcast(mod, tmp, C);
return mod->get_builder().create_matmul(A, B, C);
return mod->get_builder().create_dot(A, B, C);
}
// min
ir::value* min_expression::codegen(ir::module *mod) const {
ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod);
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
return mod->get_builder().create_select(cmp, x, y);
}
// max
ir::value* max_expression::codegen(ir::module *mod) const {
ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod);
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
return mod->get_builder().create_select(cmp, x, y);
}
// select
ir::value* select_expression::codegen(ir::module *mod) const {
ir::value* pred = pred_->codegen(mod);
ir::value* if_value = if_value_->codegen(mod);
ir::value* else_value = else_value_->codegen(mod);
return mod->get_builder().create_select(pred, if_value, else_value);
}
// Trans
ir::value* trans_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_trans(arg_->codegen(mod));
}
/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
@@ -573,6 +636,7 @@ ir::value* indexing_expression::codegen(ir::module *mod) const{
return mod->get_builder().create_reshape(in, out_shapes);
}
/* Unary operator */
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
ir::type *atype = arg->get_type();
@@ -666,7 +730,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
ir::type *ty = mod->get_scope().types.at(x->id()->name());
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
implicit_broadcast(mod, rvalue, ty);
implicit_broadcast(mod, ty, rvalue);
mod->set_value(x->id()->name(), rvalue);
}
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){

View File

@@ -1,90 +0,0 @@
#include "triton/codegen/buffer_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
namespace triton {
namespace codegen{
// run pass on module
bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void buffer_info_pass::replace(ir::value* before, ir::value *after) {
shared_.erase(before);
shared_.insert(after);
if(refs_.find(before) != refs_.end()){
ir::value* v = refs_.at(before);
refs_.erase(before);
refs_.insert({after, v});
}
}
void buffer_info_pass::run(ir::module &mod) {
// Find which buffers are shared
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(dynamic_cast<ir::matmul_inst*>(i)){
shared_.insert(i->get_operand(0));
shared_.insert(i->get_operand(1));
}
// Handles phi nodes
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()) {
if(!i->get_type()->is_tile_ty())
continue;
// handle phi
if(auto *phi = dynamic_cast<ir::phi_node*>(i))
if(is_shared(phi)){
// determine if the value is in shared memory
bool is_double = false;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::value *terminator = inc_block->get_inst_list().back();
is_double = is_double || is_loop_latch(phi, terminator);
}
// add to double-buffered
if(is_double)
double_.insert(phi);
// set references of input
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
refs_[inc_val] = phi;
}
}
}
for(auto &ref: refs_)
shared_.insert(ref.first);
}
// query double-buffered status
bool buffer_info_pass::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }
// query shared status
bool buffer_info_pass::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }
// get reference if any
ir::value *buffer_info_pass::get_reference(ir::value *x)
{ return refs_[x]; }
}
}

View File

@@ -1,56 +0,0 @@
#include "triton/codegen/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
namespace triton{
namespace codegen{
shared_view_info layout::get_shared_view(ir::value *v, unsigned idx){
return shared_views_.at(v)[idx];
}
unsigned layout::get_num_shared_views(ir::value *v){
return shared_views_.at(v).size();
}
// Phi node
void layout::add_phi_nodes(ir::value *v){
if(ir::phi_node *phi = dynamic_cast<ir::phi_node*>(v))
if(shared_views_.find(phi) != shared_views_.end())
for(ir::value *v: phi->ops()){
shared_views_[v] = shared_views_[phi];
for(shared_view_info &info: shared_views_[v])
info.has_dedicated_storage = false;
}
}
// Memory Layout
void layout::add_shared_views(ir::value *v){
// GEMM has shared inputs
if(dynamic_cast<ir::matmul_inst*>(v))
shared_views_[v].push_back({v, true});
if(dynamic_cast<ir::reshape_inst*>(v))
shared_views_[v].push_back({v, true});
}
// Entry point
void layout::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Non-phis
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()) {
add_shared_views(instr);
}
// Phi nodes
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()) {
add_phi_nodes(instr);
}
}
}
}
}

View File

@@ -0,0 +1,14 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/optimize_cse.h"
namespace triton {
namespace codegen{
void optimize_cse::run(ir::module &mod) {
}
}
}

View File

@@ -0,0 +1,50 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/tune.h"
namespace triton {
namespace codegen{
inline bool is_trans(ir::value *v){
return dynamic_cast<ir::trans_inst*>(v) != nullptr;
}
void optimize_dot::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// iterate
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(auto dot = dynamic_cast<ir::dot_inst*>(i))
if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1)
if(!dot->is_a_trans() && !dot->is_b_trans()){
builder.set_insert_point(i);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
// dot(op(a), trans(b))
if(is_trans(B)){
ir::value* BN = ((ir::trans_inst*)B)->get_operand(0);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BN, D));
dot->replace_all_uses_with(NT);
to_delete.push_back((ir::instruction*)B);
to_delete.push_back(dot);
}
// dot(op(a), b)
if(!is_trans(B)){
ir::value* BT = builder.create_trans(B);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BT, D));
dot->replace_all_uses_with(NT);
to_delete.push_back(dot);
}
}
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}
}

View File

@@ -0,0 +1,71 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/optimize_trans.h"
namespace triton {
namespace codegen{
ir::value* optimize_trans::replace_phi(ir::value* value,
std::vector<ir::instruction*>& to_delete,
ir::builder& builder){
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n));
phi->replace_all_uses_with(result);
to_delete.push_back(phi);
return result;
}
else if(auto i = dynamic_cast<ir::instruction*>(value)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
i->replace_all_uses_with(trans);
trans->set_operand(0, i);
return trans;
}
throw std::runtime_error("cannot transpose phi");
}
void optimize_trans::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// iterate
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// filter transposition
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
continue;
ir::value* op = *ops.begin();
// chains of transpositions
// TODO
// trans(phi) -> phi(trans(), trans()...)
if(dynamic_cast<ir::phi_node*>(op)){
ir::value* new_phi = replace_phi(op, to_delete, builder);
to_delete.push_back(trans);
trans->replace_all_uses_with(new_phi);
}
}
}
// erase dead code
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}
}

View File

@@ -1,6 +1,6 @@
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/target.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
@@ -309,7 +309,47 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand());
return builder.Insert(new LoadInst(ptr));
LoadInst *result = new LoadInst(ptr);
return builder.Insert(result);
}
if(ir::store_inst* ii = dynamic_cast<ir::store_inst*>(inst)){
Value *val = value(ii->get_value_operand());
Value *ptr = value(ii->get_pointer_operand());
builder.CreateStore(val, ptr);
return nullptr;
}
if(ir::select_inst* ii = dynamic_cast<ir::select_inst*>(inst)){
Value *pred = value(ii->get_operand(0));
Value *if_value = value(ii->get_operand(1));
Value *else_value = value(ii->get_operand(2));
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
}
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, builder, 0);
Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder.SetInsertPoint(tid_0_bb);
Value *cas_ptr = value(ii->get_operand(0));
Value *cas_cmp = value(ii->get_operand(1));
Value *cas_val = value(ii->get_operand(2));
Value *old = builder.CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
old = builder.CreateExtractValue(old, {0});
builder.CreateStore(old, ptr);
builder.CreateBr(tid_0_done_bb);
builder.SetInsertPoint(tid_0_done_bb);
tgt_->add_barrier(module, builder);
Value *res = builder.CreateLoad(ptr);
return (Instruction*)res;
}
// unknown instruction
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
@@ -446,7 +486,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
if(buffer_info_->is_shared(v))
return;
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
@@ -490,20 +530,11 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
shapes2.push_back(shape->get_value());
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
// create shared tile
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
if(buffer_info_->is_shared(v)){
// shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
// TODO - buffer info not up-to-date with references
if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
if(!has_phi_user(v)){
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
}
}
// phi-node (double-buffering)
else if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
unsigned id_pre = 0, id_loop = 1;
if(phi->get_incoming_block(0) == phi->get_parent())
@@ -522,13 +553,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
ir::basic_block* inc_block = phi->get_incoming_block(i);
ir::value* inc_value = phi->get_incoming_value(i);
ir::value* terminator = inc_block->get_inst_list().back();
ir::instruction* terminator = inc_block->get_inst_list().back();
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
}
}
else
throw std::runtime_error("unknown shared memory tile");
else {
if(!has_phi_user(v)){
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
}
}
}
// create distributed tile
else {
@@ -607,10 +644,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
set_mask_insert_pt(idx);
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
// store->setAlignment(16);
builder.Insert(store);
});
}
else {
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
return;
}
tile *ti = tmap_[ins];
distributed_tile* result = (distributed_tile*)ti;
if(!ins->get_type()->is_tile_ty())
@@ -727,31 +770,67 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(idx, in->get_value(idx));
});
}
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins) || (buffer_info_->is_double(ins)))
// trans
else if(dynamic_cast<ir::trans_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
in->for_each([&](indices_t idx){
indices_t out_idx = idx;
std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
ti->set_value(out_idx, in->get_value(idx));
});
}
else if(buffer_info_->is_shared(ins))
return;
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
// dot
else if(auto dot = dynamic_cast<ir::dot_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
bool AT = dot->is_a_trans();
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)};
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
}
else
{
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
indices_t a_idx = {idx[0], builder.getInt32(0)};
indices_t b_idx = {builder.getInt32(0), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
result->set_value(idx, res);
});
}
}
// element-wise
else {
@@ -858,6 +937,7 @@ void selection::run(ir::module &src, Module &dst) {
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
sh_mem_ptr_ = sh_mem_ptr;
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
@@ -890,7 +970,7 @@ void selection::run(ir::module &src, Module &dst) {
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block* inc_block = phi->get_incoming_block(n);
ir::value* inc_val = phi->get_incoming_value(n);
ir::value* terminator = inc_block->get_inst_list().back();
ir::instruction* terminator = inc_block->get_inst_list().back();
BasicBlock *llvm_inc_block = last_block.at(inc_block);
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
@@ -920,8 +1000,8 @@ void selection::run(ir::module &src, Module &dst) {
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
PHINode *llvm_phi = (PHINode*)llvm_value(phi, dst_builder);
Value *llvm_inc_val = llvm_value(inc_val, dst_builder);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
}
}

View File

@@ -1,40 +0,0 @@
#include <algorithm>
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/buffer_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
namespace triton {
namespace codegen{
void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) {
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
}
void place_shared_copy::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(info_->is_shared(i) && !info_->is_double(i))
add_copy(i, builder);
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(i))
info_->replace(cts->get_operand(0), cts);
}
}
}

View File

@@ -1,7 +1,6 @@
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/layout.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -11,14 +10,14 @@
namespace triton{
namespace codegen{
unsigned allocation::get_num_bytes(ir::value *x) {
unsigned result = x->get_type()->get_tile_bitwidth() / 8;
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
if(buffer_info_->is_double(x))
result *= 2;
return result;
}
void allocation::run(){
void shmem_allocation::run(){
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;

View File

@@ -1,7 +1,7 @@
#include <algorithm>
#include "triton/codegen/barriers.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -12,7 +12,7 @@ namespace triton {
namespace codegen{
bool barriers::intersect(const interval_vec_t &X, interval_t x) {
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
@@ -20,31 +20,31 @@ bool barriers::intersect(const interval_vec_t &X, interval_t x) {
});
}
bool barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void barriers::add_reference(ir::value *v, interval_vec_t &res){
if(dynamic_cast<ir::copy_to_shared_inst*>(v)){
void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
unsigned offset = alloc_->get_offset(v);
unsigned num_bytes = alloc_->get_num_bytes(v);
res.push_back(interval_t(offset, offset + num_bytes));
}
}
void barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
}
void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i))
add_reference(i, res);
}
void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,16 +63,16 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
}
}
barriers::interval_vec_t barriers::join(const std::vector<interval_vec_t>& intervals) {
barriers::interval_vec_t result;
shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector<interval_vec_t>& intervals) {
shmem_barriers::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
}
std::pair<barriers::interval_vec_t,
barriers::interval_vec_t> barriers::transfer(ir::basic_block *block,
std::pair<shmem_barriers::interval_vec_t,
shmem_barriers::interval_vec_t> shmem_barriers::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc) {
@@ -83,13 +83,13 @@ std::pair<barriers::interval_vec_t,
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
bool read_while_written = intersect(new_written_to, read);
bool written_while_read = intersect(new_read_from, written);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering: write and phi-node read won't intersect
if(dynamic_cast<ir::copy_to_shared_inst*>(i) &&
if(buffer_info_->is_shared(i) &&
buffer_info_->is_double(buffer_info_->get_reference(i)))
written_while_read = false;
if(read_while_written || written_while_read) {
write_after_read = false;
if(read_after_write || write_after_read) {
insert_loc.insert(i);
new_written_to.clear();
new_read_from.clear();
@@ -100,7 +100,7 @@ std::pair<barriers::interval_vec_t,
return std::make_pair(new_written_to, new_read_from);
}
void barriers::run(ir::module &mod) {
void shmem_barriers::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);

135
lib/codegen/shmem_info.cpp Normal file
View File

@@ -0,0 +1,135 @@
#include "triton/codegen/shmem_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
namespace triton {
namespace codegen{
// run pass on module
bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void shmem_info::replace(ir::value* before, ir::value *after) {
shared_.erase(before);
shared_.insert(after);
if(refs_.find(before) != refs_.end()){
ir::value* v = refs_.at(before);
refs_.erase(before);
refs_.insert({after, v});
}
}
inline bool get_is_shared(ir::value* v) {
if(auto x = dynamic_cast<ir::atomic_cas_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::trans_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::phi_node*>(v)){
bool res = true;
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
res = res && get_is_shared(x->get_incoming_value(inc));
return res;
}
return false;
}
void add_copy(ir::value *x, ir::builder &builder) {
if(auto phi = dynamic_cast<ir::phi_node*>(x)){
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi->get_incoming_value(i), builder);
}
else {
if(get_is_shared(x))
return;
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
}
}
void shmem_info::run(ir::module &mod) {
// Add shared copies
for(ir::function *fn: mod.get_function_list()){
ir::builder builder(mod.get_context());
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::dot_inst*>(i))
if(i->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
add_copy(i->get_operand(0), builder);
add_copy(i->get_operand(1), builder);
}
}
}
// Find which buffers are shared
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(get_is_shared(i))
shared_.insert(i);
// double-buffering
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()) {
if(!i->get_type()->is_tile_ty())
continue;
// handle phi
if(auto *phi = dynamic_cast<ir::phi_node*>(i))
if(is_shared(phi)){
// determine if the value is in shared memory
bool is_double = false;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::instruction *terminator = inc_block->get_inst_list().back();
is_double = is_double || is_loop_latch(phi, terminator);
}
// add to double-buffered
if(is_double)
double_.insert(phi);
// set references of input
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
refs_[inc_val] = phi;
}
}
}
}
// query double-buffered status
bool shmem_info::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }
// query shared status
bool shmem_info::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }
// get reference if any
ir::value *shmem_info::get_reference(ir::value *x)
{ return refs_[x]; }
}
}

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/liveness.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
@@ -11,19 +11,7 @@ namespace codegen{
// Entry point
inline bool is_shared(ir::value* v) {
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::phi_node*>(v)){
bool res = true;
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
res = res && is_shared(x->get_incoming_value(inc));
return res;
}
return false;
}
void liveness::run(ir::module &mod) {
void shmem_liveness::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Assigns index to each instruction
slot_index index = 0;

View File

@@ -4,6 +4,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
using namespace llvm;
@@ -26,6 +27,12 @@ Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
@@ -33,8 +40,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
return group_id;
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -65,6 +71,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder)
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
@@ -72,8 +84,7 @@ Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder,
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
return group_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -97,7 +108,7 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
static std::array<const Argument*, 3> ids = {
@@ -105,7 +116,11 @@ Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsig
fn->arg_begin() + num_params - 2,
fn->arg_begin() + num_params - 1
};
Value* result = builder.CreateMul(builder.getInt32(stride), (Argument*)ids[ax]);
return (Argument*)ids[ax];
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
}
@@ -113,6 +128,5 @@ Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned a
return builder.getInt32(0);
}
}
}

View File

@@ -1,5 +1,4 @@
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -40,6 +39,8 @@ void tune::init_c_graph(ir::instruction *v) {
ir::type::tile_shapes_t shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
return;
else
shapes = v->get_type()->get_tile_shapes();
// Reshape
@@ -56,6 +57,14 @@ void tune::init_c_graph(ir::instruction *v) {
// Splat
else if(dynamic_cast<ir::splat_inst*>(v)){
}
// Trans
else if(dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0);
size_t n_shapes = shapes.size();
for(unsigned i = 0; i < n_shapes; i++){
add_constraint({v, (i + 1) % n_shapes}, {op, i});
}
}
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){
@@ -68,7 +77,7 @@ void tune::init_c_graph(ir::instruction *v) {
}
}
// Matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(v)){
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *D = v->get_operand(2);
add_constraint({v, 0}, {D, 0});
add_constraint({v, 1}, {D, 1});
@@ -119,6 +128,13 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
if(seen.insert(x.second).second && !x.second->has_value()){
result.push_back(x.second);
}
for(auto x: mod.globals()){
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
if(seen.insert(mp).second && !mp->has_value())
result.push_back(mp);
}
return result;
}
@@ -145,23 +161,22 @@ void tune::run(ir::module &mod) {
// Layout parameters
while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
nts->set_value(1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
}
}
// Simplify metaparameters
std::set<ir::metaparameter*> fixed_io_nts;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(dynamic_cast<ir::load_inst*>(i) || dynamic_cast<ir::store_inst*>(i))
if(i->get_type()->is_tile_ty())
for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++)
fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d)));
for(ir::metaparameter* mp: fixed_io_nts)
mp->set_value(1);
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
*params_.at(i).at("nts.d0") = *tmp;
}
}
void tune::init(ir::module &mod) {

View File

@@ -64,9 +64,6 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
host_buffer::host_buffer(driver::context *context, size_t size)
: buffer(context, host_buffer_t(), true){
hst_->data = new char[size];
std::cout << size << std::endl;
std::cout << "allocating " << (float*)hst_->data << std::endl;
std::cout << *((float*)(hst_->data) + 512*500) << std::endl;
}
//

View File

@@ -106,7 +106,11 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
const std::string& features,
file_type_t ft) {
init_llvm();
// debug
// llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass());
// pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
@@ -249,6 +253,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -264,11 +269,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
}
}
cu_buffer cu_module::symbol(const char *name) const{
cu_buffer* cu_module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
return cu_buffer(ctx_, handle, false);
return new cu_buffer(ctx_, handle, false);
}

View File

@@ -285,6 +285,10 @@ value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes,
return insert(broadcast_inst::create(arg, shapes, name));
}
value *builder::create_downcast(value *arg, const std::string &name) {
return insert(downcast_inst::create(arg, name));
}
//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
@@ -293,8 +297,24 @@ value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::valu
return insert(get_global_range_inst::create(ctx_, axis, size, name));
}
value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) {
return insert(matmul_inst::create(A, B, C, name));
value *builder::create_get_range_id(unsigned axis, const std::string &name) {
return insert(get_range_id_inst::create(ctx_, axis, name));
}
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
}
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
return insert(dot_inst::create_nn(A, B, C, name));
}
value *builder::create_trans(value *A, const std::string &name) {
return insert(trans_inst::create(A, name));
}
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
return insert(select_inst::create(pred, if_value, else_value, name));
}
//===----------------------------------------------------------------------===//

View File

@@ -28,6 +28,8 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const
void instruction::erase_from_parent() {
parent_->erase(this);
for(ir::value* op: ops())
op->erase_use(this);
}
bool instruction::has_tile_result_or_op() {
@@ -482,27 +484,82 @@ instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shape
return new broadcast_inst(arg, shapes, name, next);
}
// downcast
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next);
}
//===----------------------------------------------------------------------===//
// matmul_inst classes
//===----------------------------------------------------------------------===//
matmul_inst::matmul_inst(value *A, value *B, value *C,
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), 3, 0, name, next) {
: builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) {
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
}
instruction *matmul_inst::create(value *A, value *B, value *C,
instruction *dot_inst::create_nn(value *A, value *B, value *C,
const std::string &name, instruction *next) {
return new matmul_inst(A, B, C, name, next);
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
}
instruction *dot_inst::create_nt(value *A, value *B, value *C,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, Trans, name, next);
}
instruction *dot_inst::create_tn(value *A, value *B, value *C,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, NoTrans, name, next);
}
instruction *dot_inst::create_tt(value *A, value *B, value *C,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, Trans, name, next);
}
//===----------------------------------------------------------------------===//
// trans instructions
//===----------------------------------------------------------------------===//
ir::type* trans_inst::get_res_ty(ir::type* ty) {
auto shapes = ty->get_tile_shapes();
std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
return tile_type::get(ty->get_scalar_ty(), shapes);
}
trans_inst::trans_inst(value *arg, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
set_operand(0, arg);
}
instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) {
return new trans_inst(arg, name, next);
}
//===----------------------------------------------------------------------===//
// select instructions
//===----------------------------------------------------------------------===//
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
: builtin_inst(if_value->get_type(), 3, 1, name, next){
set_operand(0, pred);
set_operand(1, if_value);
set_operand(2, else_value);
}
instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
return new select_inst(pred, if_value, else_value, name, next);
}
//===----------------------------------------------------------------------===//
// builtin instructions
//===----------------------------------------------------------------------===//
// get_global_range
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis) {
@@ -516,6 +573,28 @@ instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::ti
return new get_global_range_inst(tile_ty, axis, name, next);
}
// get_range_id
get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
}
instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
}
// atomic cas
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) {
set_operand(0, ptr);
set_operand(1, cmp);
set_operand(2, val);
}
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
return new atomic_cas_inst(ptr, cmp, val, name, next);
}
//===----------------------------------------------------------------------===//
// intrinsic instructions
//===----------------------------------------------------------------------===//
@@ -530,7 +609,7 @@ vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, inst
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), 0, 1, name, next){ }
: instruction(type::get_void_ty(ctx), 0, 0, name, next){ }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
return new barrier_inst(ctx, name, next);

View File

@@ -128,6 +128,9 @@ ir::value *module::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
const std::string& module::get_name() {
return name_;
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){

View File

@@ -172,7 +172,7 @@ unsigned tile_type::get_bitwidth() const {
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
assert(elt_ty && "Can't get a tile of <null> type!");
assert(shapes.size() && "Can't create a tile with empty shapes!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
// look-up
context_impl *impl = elt_ty->get_context().p_impl.get();
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];

View File

@@ -68,7 +68,7 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
passes.selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
@@ -79,14 +79,14 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
return std::unique_ptr<llvm::Module>(result);
}
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &name, const std::string &src) {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
// create Triton-IR from AST
ir::module* module = new ir::module("matrix", triton_context_);
ir::module* module = new ir::module(name, triton_context_);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
@@ -97,18 +97,20 @@ jit::jit(driver::context *context): driver_context_(context),
}
void jit::autotune(const std::string &src, benchmark_t benchmark) {
void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) {
// find metaparameters
auto ptt_module = make_triton_module(src);
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// std::cout << ranges.size() << std::endl;
// iterate over parameters
unsigned i;
double best = 0;
@@ -117,51 +119,56 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
i = 0;
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
passes.target_independent(tt_module);
passes.tune.init(tt_module);
if(!passes.tune.check_constraints(errors))
return;
// Deep copy of the module and tuner
auto ptt_module = make_triton_module(src);
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
mp->set_value(params[i++]);
}
passes.tune.init(tt_module);
passes.init(tt_module);
passes.target_dependent(tt_module);
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
launch_information info = launch_info_map_.at("matmul");
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name.c_str()));
launch_information info = launch_info_map_.at(name.c_str());
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
modules_.push_back(module.get());
double perf;
perf = benchmark(kernel.get(), info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
modules_.pop_back();
});
}
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
unsigned i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
mp->set_value(params[i++]);
passes.tune.init(tt_module);
passes.init(tt_module);
passes.target_dependent(tt_module);
// check constraints
std::map<ir::value*, std::vector<std::string>> errors;
passes.tune.check_constraints(errors);
@@ -184,8 +191,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
}
void jit::add_module(const std::string &src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(src);
void jit::add_module(const std::string &name, const std::string &src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(name, src);
add_module(*ptt_module, params);
}
@@ -201,4 +208,9 @@ unsigned jit::get_int(const std::string &name){
return global_ints_.at(name);
}
driver::buffer *jit::get_buffer(const std::string &name){
driver::cu_module *mod = (driver::cu_module*)modules_.front();
return mod->symbol(name.c_str());
}
}