[CODEGEN][ANALYSIS] Fixed issue in layout inference

Philippe Tillet, 2020-08-07 00:49:04 -04:00
committed by Philippe Tillet
parent 89e456107b
commit ba9955ae39
9 changed files with 109 additions and 72 deletions

View File

@@ -38,19 +38,22 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-2){
float norm = 1;
tmr.start();
for(size_t i = 0; i < 10; i++){
// while(total_time*1e-9 < 1e-2){
// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
// times.push_back(norm*tmr.get().count());
// total_time+=times.back();
}
return *std::min_element(times.begin(), times.end());
stream->synchronize();
return (float)tmr.get().count() / 10;
// return *std::min_element(times.begin(), times.end());
}
}
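
The rewritten bench() above drops the adaptive while-loop that collected per-launch samples (and the commented-out SM-clock normalization) in favor of timing ten back-to-back launches and reporting their average. A minimal sketch of that scheme, using std::chrono and a generic synchronize callback in place of Triton's timer and driver stream (all names here are illustrative, not the actual Triton code):

#include <chrono>
#include <functional>

// Sketch: warm up once, then time N asynchronous launches bracketed by a
// single synchronization and report the average latency in nanoseconds.
double bench_sketch(const std::function<void()>& op,
                    const std::function<void()>& synchronize,
                    size_t repeat = 10) {
  op();                                   // warm-up launch
  synchronize();                          // start timing from an idle device
  auto start = std::chrono::high_resolution_clock::now();
  for (size_t i = 0; i < repeat; i++)
    op();                                 // enqueue launches, no per-launch syncs
  synchronize();                          // wait for all of them to finish
  auto end = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<double, std::nano>(end - start).count() / repeat;
}

Averaging a fixed number of launches is cheaper and more predictable than the previous sample-until-10ms-then-take-the-minimum strategy, at the cost of less outlier filtering.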

View File

@@ -16,29 +16,6 @@ namespace analysis{
* Helper Functions *
* -------------------------------- */
inline int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
inline int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
unsigned lo = std::min(a, b);
unsigned hi = std::max(a, b);
@@ -210,8 +187,7 @@ scanline_layout::scanline_layout(size_t num_warps,
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
int max_contiguous = shape_[i] / (num_warps*32);
nts_[i] = clamp(size / num_threads, 1, gcd(contiguous, max_contiguous));
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
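
This hunk is the layout-inference fix named in the commit title: the per-thread vector width nts_[i] of a scanline layout is now capped by min(contiguous, shape_[i]) instead of gcd(contiguous, shape_[i] / (num_warps*32)). Presumably the failure mode being fixed is that the gcd-based cap collapses to 1 whenever its two arguments are coprime (e.g. gcd(4, 3) = 1 even though a vector width of 4 is legal), forcing a needlessly scalar layout. A minimal sketch of the arithmetic after the fix, with simplified inputs (illustrative, not the actual analysis pass):

#include <algorithm>
#include <vector>

// clamp x into [min(a,b), max(a,b)], mirroring the helper kept in the pass
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b);
  unsigned hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

// Sketch: distribute num_warps*32 threads over a tile, dimension by dimension.
// nts[i] = elements handled contiguously by one thread along dim i,
// mts[i] = number of threads laid out along dim i.
void scanline_sketch(const std::vector<int>& shape,
                     const std::vector<int>& contiguous,  // proven pointer contiguity per dim
                     int num_warps,
                     std::vector<int>& nts, std::vector<int>& mts) {
  nts.assign(shape.size(), 1);
  mts.assign(shape.size(), 1);
  int num_threads = num_warps * 32;
  int size = 1;
  for (int s : shape) size *= s;                    // remaining elements to cover
  for (size_t i = 0; i < shape.size(); i++) {
    // vector width: as many elements per thread as possible, but never more
    // than the proven contiguity or the dimension itself (the fix)
    nts[i] = clamp(size / num_threads, 1, std::min(contiguous[i], shape[i]));
    // threads along this dimension, bounded by how many vectors fit in it
    mts[i] = clamp(num_threads, 1, shape[i] / nts[i]);
    size /= shape[i];
    num_threads /= mts[i];
  }
}

In the real pass contiguous[i] is itself already capped at 4 (the widest vectorized access), so the new rule effectively reads: vectorize as much as the proven contiguity allows, then spread the remaining threads over the dimension.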

View File

@@ -21,13 +21,11 @@ class _conv(torch.autograd.Function):
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
/*
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
*/
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
@@ -68,10 +66,12 @@ class _conv(torch.autograd.Function):
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// reduction loop
float acc[TM, TN] = 0;
@@ -81,8 +81,6 @@ class _conv(torch.autograd.Function):
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta[newaxis, :];
// increment B
pb += TK * ldb_s;
// bounds-checking A
rk += TK;
rs = rk % SS;
@@ -90,7 +88,9 @@ class _conv(torch.autograd.Function):
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
@@ -152,18 +152,18 @@ class _conv(torch.autograd.Function):
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if dtype not in _conv.kernel:
TK = 8
defines = {
'TYPE' : dtype,
'TM' : [64, 128],
'TN' : [64, 128],
'TK' : [8],
'TM' : [16, 32, 64, 128],
'TN' : [16, 32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'LUTSIZE' : 4*CI*R*S,
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + 8, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
@@ -186,7 +186,7 @@ class _conv(torch.autograd.Function):
conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 32, 64, 256, 2048, 3, 3
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
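
Two things change in this conv wrapper: the A-operand bounds check gains an explicit rm < M mask (checkam), and the pointer-delta look-up table is now built from the same TK the kernel is compiled with instead of a hard-coded 8. That table stores, for every position of the flattened (CI, R, S) reduction axis, how far the A pointer must jump when the reduction index advances by TK steps. A sketch of the same arithmetic in plain C++ (the unpack order is an assumption, since _conv.unpack is not shown here; stride names are illustrative):

#include <vector>

// Sketch: precompute, for each flattened (ci, r, s) reduction index, the
// element offset that moves an A pointer forward by TK reduction steps.
std::vector<int> build_delta_lut(int CI, int R, int S, int TK,
                                 int stride_ci, int stride_h, int stride_w) {
  auto unpack = [&](int idx, int& ci, int& r, int& s) {
    ci = idx / (R * S);          // slowest-varying: input channel (assumed order)
    r  = (idx / S) % R;          // filter row
    s  = idx % S;                // fastest-varying: filter column
  };
  std::vector<int> delta(CI * R * S);
  for (int idx = 0; idx < (int)delta.size(); idx++) {
    int ci, r, s, nci, nr, ns;
    unpack(idx, ci, r, s);
    unpack(idx + TK, nci, nr, ns);        // where the pointer sits TK steps later
    // r/s offsets move the pointer along the input's H/W axes
    delta[idx] = (nci - ci) * stride_ci + (nr - r) * stride_h + (ns - s) * stride_w;
  }
  return delta;
}

Deriving the look-ahead from the compiled TK keeps it consistent with the kernel's "padelta += TK; pa += adelta" increments; with the previous hard-coded 8, a TK other than 8 would presumably have advanced the A pointer by the wrong amount.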

View File

@@ -112,7 +112,7 @@ class _dot(torch.autograd.Function):
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0),
grid=grid, bench=100)
print(time)
print(2*M*N*K/(time*1e-6)*1e-9)
return c
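
Instead of the raw kernel time, the benchmark now prints arithmetic throughput: a GEMM performs about 2·M·N·K floating-point operations (one multiply and one add per inner-product term), so dividing by the runtime in seconds gives FLOP/s. Assuming the returned time is in nanoseconds, as the C++ benches elsewhere in this commit are, 2*M*N*K/(time*1e-6)*1e-9 works out to TFLOP/s. A small worked example under that assumption (the 11 ms figure is purely illustrative):

#include <cstdio>

// Sketch: FLOP count of a GEMM divided by its runtime, assuming time in ns.
double tflops(double M, double N, double K, double time_ns) {
  double flop = 2.0 * M * N * K;              // one mul + one add per term
  return flop / (time_ns * 1e-9) * 1e-12;     // same value as 2*M*N*K/(time*1e-6)*1e-9
}

int main() {
  // e.g. an 8192 x 8192 x 8192 GEMM finishing in 11 ms is roughly 100 TFLOP/s
  std::printf("%.1f TFLOP/s\n", tflops(8192, 8192, 8192, 11e6));
  return 0;
}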

View File

@@ -170,8 +170,8 @@ class kernel:
arg_types = libtriton.get_fn_signature(self.src, opt)
self.fw_op = _make_framework_op(arg_types)
def set_constant(self, name, value):
libtriton.register_cst(self.op_id, name, value)
def set_constant(self, device, name, value):
libtriton.register_cst((self.op_id, device), name, value)
def __call__(self, *args, **kwargs):
for x in args:
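
set_constant now takes the target device and keys the registered constant by the (op_id, device) pair rather than by op_id alone, presumably so the same kernel object can hold a separate copy of a constant on each device it runs on. A sketch of what such a registry could look like (names and types are illustrative, not libtriton's internals):

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Sketch: constants registered per (op, device) instead of per op.
using cst_key = std::pair<std::size_t /*op_id*/, int /*device*/>;
static std::map<cst_key, std::map<std::string, std::vector<int64_t>>> cst_registry;

void register_cst(std::size_t op_id, int device,
                  const std::string& name, std::vector<int64_t> value) {
  cst_registry[{op_id, device}][name] = std::move(value);
}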

View File

@@ -1,4 +1,4 @@
foreach(PROG dot copy)
foreach(PROG dot copy conv)
set(TARGET bench_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})

View File

@@ -9,11 +9,40 @@ int main() {
// shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
for(auto ord: std::vector<std::vector<int>>{{0, 1}})
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
std::vector<config_t> tmp = {
// config_t{ord, x[0], x[1], 128, 128, 128},
// config_t{ord, x[0], x[1], 256, 256, 256},
// config_t{ord, x[0], x[1], 384, 384, 384},
// config_t{ord, x[0], x[1], 512, 512, 512},
config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 768, 768, 768},
// config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
config_t{ord, x[0], x[1], 8192, 8192, 8192},
// config_t{ord, x[0], x[1], 256, 16, 256},
// config_t{ord, x[0], x[1], 512, 16, 512},
// config_t{ord, x[0], x[1], 768, 16, 768},
// config_t{ord, x[0], x[1], 1024, 16, 1024},
// config_t{ord, x[0], x[1], 1280, 16, 1280},
// config_t{ord, x[0], x[1], 1536, 16, 1536},
// config_t{ord, x[0], x[1], 2048, 16, 2048},
// config_t{ord, x[0], x[1], 3072, 16, 3072},
// config_t{ord, x[0], x[1], 4096, 16, 4096},
// config_t{ord, x[0], x[1], 5120, 16, 5120},
// config_t{ord, x[0], x[1], 6144, 16, 6144},
// config_t{ord, x[0], x[1], 7168, 16, 7168},
// config_t{ord, x[0], x[1], 64, 64, 4096},
// config_t{ord, x[0], x[1], 64, 64, 8192},
// config_t{ord, x[0], x[1], 64, 64, 16384},
// config_t{ord, x[0], x[1], 64, 64, 32768},
// config_t{ord, x[0], x[1], 64, 64, 65536},
// config_t{ord, x[0], x[1], 64, 64, 131072}
// config_t{ord, x[0], x[1], 127008, 768, 576},
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
// config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -36,7 +65,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c ;
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

View File

@@ -82,6 +82,8 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
// macros
rt::function::options_space_t opt;
@@ -110,15 +112,21 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"32"}});
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.defines.push_back({"TZ", {"1"}});
opt.num_warps = {4};
}
// kernels
rt::function function(src::dot, opt);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
auto grid = grid2d(M, N);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc, &*dlocks};
auto grid = [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
// metrics
if(mode == BENCH){
@@ -126,13 +134,13 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
// // cublas
// cublas
// if(cublas::cublasinit()){
// T alpha(static_cast<double>(1));
// T beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream);
// bench.push_back(tflops(cublas_ms));
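
Two driver-side changes support the new kernel: a zero-initialized locks buffer is allocated and passed as an extra argument (the kernel uses one lock plus one counter per output tile), and the launch grid is now computed per compiled configuration from the chosen TM/TN tile sizes and the split-K factor TZ, via a callback instead of the fixed grid2d(M, N). Assuming ceil in the diff is integer ceiling division, the grid computation amounts to the following sketch (names are illustrative):

#include <array>
#include <cstddef>

// Sketch: one program per TM x TN output tile, times TZ split-K slices.
inline std::size_t ceil_div(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

std::array<std::size_t, 3> make_grid(std::size_t M, std::size_t N,
                                     int TM, int TN, int TZ) {
  return { ceil_div(M, (std::size_t)TM),   // tiles along M
           ceil_div(N, (std::size_t)TN),   // tiles along N
           (std::size_t)TZ };              // independent reduction slices
}

Deriving the grid from the options is what keeps TM/TN auto-tunable: each candidate configuration gets its own launch geometry.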

View File

@@ -6,13 +6,15 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M, int N, int K,
int M, int N, int K __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
int ldc __multipleof(8),
int* locks) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
@@ -20,7 +22,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// pointers to operands
int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
@@ -35,9 +40,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE b[SHAPE_B] = checkb ? *pb : 0;
// reduction loop
float c[TM, TN] = 0;
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USEA @ USEB;
acc += USEA @ USEB;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
@@ -45,7 +50,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
@@ -53,7 +59,22 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";