triton/python/src/torch/superblock.cc
Philippe Tillet 269ebc12e5 [PYTHON][TESTS][DOC] Various improvements to the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency:
  > `.data_ptr()` must now be passed as a kernel argument; no more implicit conversion from `torch.Tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.Tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* C++ headers are no longer copied into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
2021-07-27 12:38:48 -07:00
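
For context, a launch under the simplified API described in this commit message looks roughly like the sketch below. This is a minimal, hypothetical sketch rather than code from this commit: the `.data_ptr()` arguments, the explicit `torch.device`, and the `opt.VAR` attribute access are stated in the message above, but the `defines` keyword, the `grid` callable, and `triton.cdiv` are assumptions based on tutorials of this era.

# hypothetical sketch of a 1.0alpha-style kernel launch; keyword names
# other than those confirmed in the commit message are assumptions
import torch
import triton

_src = """
__global__ void add(float* z, float* x, float* y, int N) {
  int pid = get_program_id(0);
  int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
  *(z + off) = *(x + off) + *(y + off);
}
"""

device = torch.device('cuda')  # now passed explicitly, not inferred from tensors
x = torch.rand(1024, device=device)
y = torch.rand(1024, device=device)
z = torch.empty_like(x)
kernel = triton.kernel(_src, device=device, defines={'BLOCK': 256})
# compilation options are constant attributes: opt.BLOCK, not opt.d('BLOCK')
grid = lambda opt: (triton.cdiv(x.numel(), opt.BLOCK),)
# raw pointers are passed via .data_ptr(); no implicit tensor conversion
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), x.numel(), grid=grid)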


#include <torch/extension.h>
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
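// Extracts dense `max_width` x `max_width` squares of nonzero blocks from
// `layout` (an H x M x N integer mask, one M x N slice per head).
// For each head, nonzero entries are scanned in row-major order while two
// ring buffers remember the last `max_width` nonzero columns (ii_left) and,
// per column, the last `max_width` nonzero rows (ii_top). A maximal-square
// dynamic program over these surviving nonzeros detects when max_width^2
// blocks line up into a square; such squares are then zeroed out of `layout`
// and their (h, m, n, idx) coordinates appended to `scratch`. Per-head
// results are concatenated and pushed onto `ret` together with `max_width`.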
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
  size_t H = layout.size(0);
  size_t M = layout.size(1);
  size_t N = layout.size(2);
  torch::Tensor tmp = torch::zeros_like(layout);
  auto _tmp = tmp.accessor<int, 3>();
  auto _layout = layout.accessor<int, 3>();
  auto _idx = idx.accessor<int, 3>();
  auto _scratch = scratch.accessor<int, 3>();
  std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for(size_t h = 0; h < H; h++){
    // ring buffers holding the columns (ii_left) and, for each column, the
    // rows (ii_top) of the last `max_width` nonzero blocks visited
    std::vector<int> ii_left(max_width, -1);
    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
    for(size_t m = 0; m < M; m++){
      for(size_t n = 0; n < N; n++){
        int v = _layout[h][m][n];
        if(v == 0)
          continue;
        int n_left = ii_left[max_width-1];
        int m_top = ii_top[max_width-1][n];
        // maximal-square recurrence: the square ending at (m, n) is one wider
        // than the smallest square ending at its left, top, and top-left
        // surviving neighbors
        int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
        int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
        int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
        int width = std::min(left, std::min(top, topleft)) + 1;
        // reset width if blocks cannot be packed
        // together (i.e., there is a nonzero "in the middle")
        for(int nn = n_left + 1; nn < (int)n; nn++)
          if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
            width = 1;
        _tmp[h][m][n] = width;
        // update ii_left ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_left[k] = ii_left[k+1];
        ii_left[max_width-1] = n;
        // update ii_top ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_top[k][n] = ii_top[k+1][n];
        ii_top[max_width-1][n] = m;
        // square is too small -- skip
        if(width != max_width)
          continue;
        // retained blocks are zeroed out of the layout and recorded in scratch
        for(int km = 0; km < max_width; km++)
        for(int kn = 0; kn < max_width; kn++){
          int mm = ii_top[km][n];
          int nn = ii_left[kn];
          if(mm < 0 || nn < 0)
            continue;
          _layout[h][mm][nn] = 0;
          _tmp[h][mm][nn] = 0;
          _scratch[h][current[h]][0] = (int)h;
          _scratch[h][current[h]][1] = (int)mm;
          _scratch[h][current[h]][2] = (int)nn;
          _scratch[h][current[h]][3] = _idx[h][mm][nn];
          current[h]++;
        }
      }
    }
  }
  std::vector<torch::Tensor> to_cat;
  for(size_t h = 0; h < H; h++)
    if(current[h] > 0)
      to_cat.push_back(scratch[h].slice(0, 0, current[h]));
  if(!to_cat.empty())
    ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
}
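
// Greedily decomposes a block-sparse `layout` into square super-blocks of
// decreasing width: every nonzero block first receives a sequential index,
// then segment_blocks() runs with widths start_width, start_width/2, ..., 1.
// Each pass zeroes the blocks it packs, so narrower passes only see what the
// wider ones left behind; every nonzero block thus lands in exactly one
// super-block.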
ret_t superblock(torch::Tensor layout, int start_width) {
  ret_t ret;
  // assign each nonzero block a sequential index
  torch::Tensor idx = torch::zeros_like(layout);
  int current = 0;
  int64_t H = layout.size(0);
  int64_t M = layout.size(1);
  int64_t N = layout.size(2);
  auto _layout = layout.accessor<int, 3>();
  auto _idx = idx.accessor<int, 3>();
  for(int64_t h = 0; h < H; h++)
  for(int64_t m = 0; m < M; m++)
  for(int64_t n = 0; n < N; n++){
    if(_layout[h][m][n] == 0)
      continue;
    _idx[h][m][n] = current++;
  }
  // scratch memory: at most sum(layout) packed blocks per head, 4 ints each
  torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
  for(int max_width = start_width; max_width > 0; max_width /= 2)
    segment_blocks(layout, idx, scratch, max_width, ret);
  return ret;
}
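
// Exposes superblock() to Python via the pybind11 extension module.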
void init_superblocking(pybind11::module &m) {
  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}