[CODEGEN][ANALYSIS] Fixed issue in layout inference
committed by Philippe Tillet · parent 89e456107b · commit ba9955ae39
@@ -38,19 +38,22 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-2){
float norm = 1;
tmr.start();
for(size_t i = 0; i < 10; i++){
// while(total_time*1e-9 < 1e-2){
// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
// times.push_back(norm*tmr.get().count());
// total_time+=times.back();
}
return *std::min_element(times.begin(), times.end());
stream->synchronize();
return (float)tmr.get().count() / 10;

// return *std::min_element(times.begin(), times.end());
}

}
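Note: the new timing scheme above replaces the old min-over-normalized-samples loop with a simple average over 10 back-to-back launches bracketed by a single synchronize. A minimal standalone sketch of that strategy, assuming a generic op/sync pair and std::chrono instead of Triton's timer class:

#include <chrono>
#include <functional>

// Minimal sketch of the averaging benchmark loop above.
// `op` and `sync` are placeholders for the kernel launch and stream->synchronize().
inline double bench_ns(const std::function<void()>& op,
                       const std::function<void()>& sync,
                       size_t reps = 10) {
  op();     // warm-up run
  sync();   // make sure the device is idle before timing
  auto start = std::chrono::high_resolution_clock::now();
  for(size_t i = 0; i < reps; i++)
    op();   // enqueue `reps` launches back-to-back
  sync();   // wait for all of them to finish
  auto end = std::chrono::high_resolution_clock::now();
  // average wall-clock time per launch, in nanoseconds
  return std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / double(reps);
}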
@@ -16,29 +16,6 @@ namespace analysis{
* Helper Functions *
* -------------------------------- */

inline int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}

inline int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}

inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
unsigned lo = std::min(a, b);
unsigned hi = std::max(a, b);
@@ -210,8 +187,7 @@ scanline_layout::scanline_layout(size_t num_warps,
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);

int max_contiguous = shape_[i] / (num_warps*32);
nts_[i] = clamp(size / num_threads, 1, gcd(contiguous, max_contiguous));
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
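Note: the layout fix above bounds the per-thread vector size by min(contiguous, shape_[i]) instead of gcd(contiguous, max_contiguous), avoiding cases where the gcd collapses to 1 (e.g. contiguous = 4 and max_contiguous = 3). A small standalone rerun of the new per-dimension arithmetic, with made-up values for the shape, thread count and alignment (nts plays the role of the per-thread vector size, mts the threads assigned to that dimension):

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical standalone rerun of the per-dimension loop above;
// shape, contiguous and num_threads are example values, not from the diff.
unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b), hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

int main() {
  std::vector<int> shape = {384, 64};  // example tile shape
  int num_threads = 4 * 32;            // 4 warps
  int size = shape[0] * shape[1];      // total elements per tile
  int contiguous = 4;                  // max vector width allowed by alignment
  for (size_t i = 0; i < shape.size(); i++) {
    int nts = clamp(size / num_threads, 1, std::min<int>(contiguous, shape[i]));
    int mts = clamp(num_threads, 1, shape[i] / nts);
    std::printf("dim %zu: nts=%d mts=%d\n", i, nts, mts);
    size /= shape[i];
    num_threads /= mts;
  }
}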
@@ -21,13 +21,11 @@ class _conv(torch.autograd.Function):
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
/*
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
*/
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
@@ -68,10 +66,12 @@ class _conv(torch.autograd.Function):
TYPE* pb[TK, TN] = B + offb;

// prefetches operands
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;

// reduction loop
float acc[TM, TN] = 0;
@@ -81,8 +81,6 @@ class _conv(torch.autograd.Function):
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta[newaxis, :];
// increment B
pb += TK * ldb_s;
// bounds-checking A
rk += TK;
rs = rk % SS;
@@ -90,7 +88,9 @@ class _conv(torch.autograd.Function):
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
@@ -152,18 +152,18 @@ class _conv(torch.autograd.Function):
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if dtype not in _conv.kernel:
TK = 8
defines = {
'TYPE' : dtype,
'TM' : [64, 128],
'TN' : [64, 128],
'TK' : [8],
'TM' : [16, 32, 64, 128],
'TN' : [16, 32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'LUTSIZE' : 4*CI*R*S,
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + 8, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
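Note: delta above is a look-up table of pointer increments for the implicit-GEMM reduction over CI*R*S: entry idx stores how far the A pointer must move to reach the element TK steps further along the flattened (ci, r, s) index. A rough C++ rendition of the same idea, with example strides and the unpack helper written out under the assumption that s varies fastest, then r, then ci (a guess about the helper, not taken from the diff):

#include <cstdint>
#include <vector>

// Hypothetical reconstruction of the delta LUT; CI, R, S, TK and the strides
// are caller-supplied example values.
std::vector<int32_t> make_delta(int CI, int R, int S, int TK,
                                int64_t stride_c, int64_t stride_h, int64_t stride_w) {
  // assumed index layout: s fastest, then r, then ci
  auto unpack = [&](int idx, int& ci, int& r, int& s) {
    ci = idx / (R * S);
    r  = (idx / S) % R;
    s  = idx % S;
  };
  std::vector<int32_t> delta(CI * R * S);
  for (int idx = 0; idx < CI * R * S; idx++) {
    int ci, r, s, nci, nr, ns;
    unpack(idx, ci, r, s);
    unpack(idx + TK, nci, nr, ns);  // where the pointer must point TK steps later
    delta[idx] = static_cast<int32_t>((nci - ci) * stride_c +
                                      (nr - r) * stride_h +
                                      (ns - s) * stride_w);
  }
  return delta;
}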
@@ -186,7 +186,7 @@ class _conv(torch.autograd.Function):
conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 32, 64, 256, 2048, 3, 3
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
@@ -112,7 +112,7 @@ class _dot(torch.autograd.Function):
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0),
grid=grid, bench=100)
print(time)
print(2*M*N*K/(time*1e-6)*1e-9)
return c
@@ -170,8 +170,8 @@ class kernel:
arg_types = libtriton.get_fn_signature(self.src, opt)
self.fw_op = _make_framework_op(arg_types)

def set_constant(self, name, value):
libtriton.register_cst(self.op_id, name, value)
def set_constant(self, device, name, value):
libtriton.register_cst((self.op_id, device), name, value)

def __call__(self, *args, **kwargs):
for x in args:
@@ -1,4 +1,4 @@
foreach(PROG dot copy)
foreach(PROG dot copy conv)
set(TARGET bench_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
@@ -9,11 +9,40 @@ int main() {
// shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
for(auto ord: std::vector<std::vector<int>>{{0, 1}})
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
std::vector<config_t> tmp = {
// config_t{ord, x[0], x[1], 128, 128, 128},
// config_t{ord, x[0], x[1], 256, 256, 256},
// config_t{ord, x[0], x[1], 384, 384, 384},
// config_t{ord, x[0], x[1], 512, 512, 512},
config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 768, 768, 768},
// config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
config_t{ord, x[0], x[1], 8192, 8192, 8192},

// config_t{ord, x[0], x[1], 256, 16, 256},
// config_t{ord, x[0], x[1], 512, 16, 512},
// config_t{ord, x[0], x[1], 768, 16, 768},
// config_t{ord, x[0], x[1], 1024, 16, 1024},
// config_t{ord, x[0], x[1], 1280, 16, 1280},
// config_t{ord, x[0], x[1], 1536, 16, 1536},
// config_t{ord, x[0], x[1], 2048, 16, 2048},
// config_t{ord, x[0], x[1], 3072, 16, 3072},
// config_t{ord, x[0], x[1], 4096, 16, 4096},
// config_t{ord, x[0], x[1], 5120, 16, 5120},
// config_t{ord, x[0], x[1], 6144, 16, 6144},
// config_t{ord, x[0], x[1], 7168, 16, 7168},

// config_t{ord, x[0], x[1], 64, 64, 4096},
// config_t{ord, x[0], x[1], 64, 64, 8192},
// config_t{ord, x[0], x[1], 64, 64, 16384},
// config_t{ord, x[0], x[1], 64, 64, 32768},
// config_t{ord, x[0], x[1], 64, 64, 65536},
// config_t{ord, x[0], x[1], 64, 64, 131072}

// config_t{ord, x[0], x[1], 127008, 768, 576},
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
// config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -36,7 +65,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c ;
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
@@ -82,6 +82,8 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());

// macros
rt::function::options_space_t opt;
@@ -110,15 +112,21 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"32"}});
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.defines.push_back({"TZ", {"1"}});
opt.num_warps = {4};
}

// kernels

rt::function function(src::dot, opt);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
auto grid = grid2d(M, N);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc, &*dlocks};
auto grid = [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};

// metrics
if(mode == BENCH){
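Note: the launch grid is now derived per auto-tuning configuration: ceil(M, TM) programs along M, ceil(N, TN) along N, and TZ along the reduction-splitting axis. Assuming ceil here is ordinary rounded-up integer division, it amounts to:

#include <cstddef>

// Rounded-up integer division; presumably what the ceil(M, x.D<int>("TM")) calls above compute.
inline size_t ceil_div(size_t x, size_t y) { return (x + y - 1) / y; }
// e.g. M = 8192, TM = 128  ->  ceil_div(8192, 128) == 64 programs along the M axis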
@@ -126,13 +134,13 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
bench.push_back(tflops(triton_ns));

// // cublas
// cublas
// if(cublas::cublasinit()){
// T alpha(static_cast<double>(1));
// T beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream);
// bench.push_back(tflops(cublas_ms));
@@ -6,13 +6,15 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M, int N, int K,
int M, int N, int K __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
int ldc __multipleof(8),
int* locks) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
@@ -20,7 +22,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;

// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;

// pointers to operands
int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
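Note: the two added lines implement reduction splitting (split-K): K is divided by TZ, and each program id ridz along the third grid axis reduces over its own slice of length K/TZ starting at ridz * (K/TZ). A hypothetical numeric illustration of which k-range each program covers (example sizes, not from the benchmark):

#include <cstdio>

// Hypothetical illustration of the split-K index math above (K=1024, TZ=4).
int main() {
  int K = 1024, TZ = 4;
  int Kslice = K / TZ;                 // 256 reduction steps per program
  for (int ridz = 0; ridz < TZ; ridz++)
    std::printf("ridz=%d reduces over k in [%d, %d)\n",
                ridz, ridz * Kslice, (ridz + 1) * Kslice);
}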
@@ -35,9 +40,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE b[SHAPE_B] = checkb ? *pb : 0;

// reduction loop
float c[TM, TN] = 0;
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USEA @ USEB;
acc += USEA @ USEB;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
@@ -45,7 +50,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
acc = acc * alpha;
TYPE c[TM, TN] = acc;

// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
@@ -53,7 +59,22 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;

#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
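Note: the new TZ > 1 epilogue combines the split-K partial results through a spin lock and a counter held in the locks buffer: each program busy-waits on atomic_cas(plock, 0, 1), the first arrival stores its tile, later arrivals add theirs, and the counter wraps modulo TZ so the buffer is reusable for the next launch. A host-side C++ sketch of the same protocol with std::atomic and plain threads (an analogy, not the Triton atomics themselves):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Hypothetical CPU re-enactment of the spin-lock epilogue: TZ "programs"
// each produce a partial value and fold it into a single output slot.
int main() {
  const int TZ = 4;
  std::atomic<int> lock{0}, count{0};
  float c_out = 0.f;                      // plays the role of the C tile

  auto program = [&](int ridz) {
    float partial = float(ridz + 1);      // made-up partial result
    int expected = 0;
    while (!lock.compare_exchange_weak(expected, 1)) expected = 0;  // atomic_cas(plock, 0, 1)
    int cnt = count.load();
    if (cnt == 0) c_out = partial;        // first arrival stores
    else          c_out += partial;       // later arrivals accumulate
    count.store((cnt + 1) % TZ);          // wrap so the counter is reusable
    lock.store(0);                        // release the lock
  };

  std::vector<std::thread> threads;
  for (int z = 0; z < TZ; z++) threads.emplace_back(program, z);
  for (auto& t : threads) t.join();
  std::printf("accumulated: %f (expect 10)\n", c_out);  // 1+2+3+4
}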