[PYTHON][TESTS][DOC] Various improvements to the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency (see the sketch after this list):
  > `.data_ptr()` must now be passed as a kernel argument; there is no more implicit conversion from `torch.tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
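A hedged sketch of what the new calling convention implies, assuming a simple vector-add kernel; the exact `triton.kernel` constructor and launch signatures live in the Python package rather than in this diff, so `src`, `defines`, `grid`, `TM`, and `triton.cdiv` are all illustrative names:

import torch
import triton

src = ...  # Triton-C source string, elided
# the device is now passed explicitly; it is no longer inferred from tensors
kernel = triton.kernel(src, device=torch.device('cuda:0'), defines={'TM': 128})
x = torch.randn(4096, device='cuda:0')
y = torch.empty_like(x)
# no implicit torch.tensor conversion: raw pointers are passed instead,
# and meta-parameters are read as attributes (opt.TM, not opt.d('TM'))
kernel(x.data_ptr(), y.data_ptr(), x.numel(),
       grid=lambda opt: [triton.cdiv(x.numel(), opt.TM)])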
@@ -13,15 +13,19 @@
 #include "triton/ir/function.h"
 
 using namespace triton;
 
 namespace rt = triton::runtime;
 namespace drv = triton::driver;
 
 typedef std::pair<int, int> map_key_t;
-std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
-CUstream torch_get_cuda_stream(int64_t dev_id);
-CUdevice torch_get_cuda_device(int64_t dev_id);
+std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
+std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
+std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
+extern CUstream torch_get_cuda_stream(int64_t dev_id);
+extern CUdevice torch_get_cuda_device(int64_t dev_id);
 
 /* Grid utilities */
 
@@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) {
 
 /* Function utilities */
 
-void register_fn(const map_key_t& key,
+void register_fn(int op_id,
+                 int dev_id,
                  const std::string& src,
                  const rt::options_space_t& opt) {
-  if(id_fn_map.find(key) == id_fn_map.end())
-    id_fn_map[key].reset(new rt::function(src, opt, ""));
+  if(tt_devices.find(dev_id) == tt_devices.end()) {
+    driver::device* device;
+    driver::stream* stream;
+    if(dev_id >= 0){
+      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
+      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
+    }
+    else{
+      device = new triton::driver::host_device();
+      stream = new triton::driver::host_stream();
+    }
+    tt_devices[dev_id].reset(device);
+    tt_streams[dev_id].reset(stream);
+  }
+  if(id_fn_map.find(op_id) == id_fn_map.end()){
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+  }
+  // expose integer-valued compilation defines as attributes of the cached
+  // options object (this is what makes opt.VAR work on the Python side)
+  for(const auto& k: id_fn_map[op_id]->get_kernels()){
+    const rt::options_t* opt = &k.first;
+    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
+    for(auto x: opt->defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    opt_cache_[&k.second->opt] = obj;
+  }
+}
 
-void delete_fn(const map_key_t& key) {
-  id_fn_map.erase(key);
+void delete_fn(int op_id) {
+  id_fn_map.erase(op_id);
 }
 
-std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
-  triton::driver::cu_device device(key.second, false);
-  return id_fn_map[key]->get_asm(mode, &device, opt);
-}
-
 void cleanup() {
   id_grid_map.clear();
   id_fn_map.clear();
+  opt_cache_.clear();
 }
 
 size_t make_op_id() {
   return id_fn_map.size();
 }
 
 /* Function signature */
-void make_module(const std::string& src, ir::module* ir,
-                 const runtime::options_space_t& opt) {
-  std::string copy = triton::runtime::function::preheader() + src;
-  // pre-process
-  TokenSequence tokens;
-  Preprocessor cpp(&copy, true);
-  for(auto it: opt.defines){
-    cpp.AddMacro(it.first, &it.second[0]);
-  }
-  cpp.Process(tokens);
-  // parse
-  Parser parser(tokens);
-  parser.Parse();
-  Generator gen(&parser);
-  gen.Gen(ir);
-}
-
-std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  // extract signature
-  std::vector<rt::arg_type> ret;
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++)
-    ret.push_back(rt::convert(ty->get_param_ty(i)));
-  return ret;
-}
+std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
+  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
+}
 
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
+  // for(size_t n = 0; n < constant_names.size(); n++){
+  //   const torch::Tensor& x = constant_vals[n];
+  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
+}
 
-typedef triton::runtime::options_t options_t;
-typedef triton::runtime::options_space_t options_space_t;
+pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  // wrap the Python grid callback so that it receives the options object
+  // with integer defines attached as attributes
+  auto wrapper = [&grid](const rt::options_t& opt){
+    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
+    for(auto x: opt.defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    return grid(*obj.cast<rt::options_t*>());
+  };
+  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
+  return opt_cache_.at(&kernel->opt);
+}
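For reference, a minimal sketch of how these two entry points might be driven from Python, using the struct.pack argument-passing idea credited to Scott Gray in launch.cc below; the 'PPi' layout assumes a (pointer, pointer, int) kernel signature, and the op_id registration step is elided:

import struct
import torch
import libtriton

x = torch.randn(4096, device='cuda:0')
y = torch.empty_like(x)
dev_id = 0
# op_id must come from an earlier make_op_id()/register_fn(...) pair (elided)
op_id = libtriton.make_op_id()
args = struct.pack('PPi', x.data_ptr(), y.data_ptr(), x.numel())
# the grid callback receives the cached options object; integer defines are
# attached to it as attributes by register_fn/autotune (opt.TM assumed)
grid = lambda opt: [libtriton.cdiv(x.numel(), opt.TM), 1, 1]
best = libtriton.autotune(op_id, dev_id, args, grid)
libtriton.launch_kernel(op_id, dev_id, args,
                        libtriton.cdiv(x.numel(), best.TM), 1, 1)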
+void init_superblocking(pybind11::module &m);
+void init_launch(pybind11::module &m);
 
 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
 
   // bindings for triton classes
   pybind11::enum_<rt::arg_type>(m, "arg_type")
-    .value("int1", rt::INT1_T)
-    .value("int8", rt::INT8_T)
-    .value("int16", rt::INT16_T)
-    .value("int32", rt::INT32_T)
-    .value("int64", rt::INT64_T)
-    .value("half", rt::HALF_T)
-    .value("float", rt::FLOAT_T)
+    .value("int1" , rt::INT1_T)
+    .value("int8" , rt::INT8_T)
+    .value("int16" , rt::INT16_T)
+    .value("int32" , rt::INT32_T)
+    .value("int64" , rt::INT64_T)
+    .value("half" , rt::HALF_T)
+    .value("float" , rt::FLOAT_T)
     .value("double", rt::DOUBLE_T)
     .value("buffer", rt::BUFFER_T);
 
   pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
-    .value("ptx", rt::ASM_NV_PTX)
+    .value("ptx" , rt::ASM_NV_PTX)
     .value("sass", rt::ASM_NV_SASS);
 
-  pybind11::class_<options_t>(m, "options")
-    .def(pybind11::init<>())
-    .def("d", &options_t::D<int>)
-    .def_readwrite("num_warps", &options_t::num_warps)
-    .def_readwrite("defines"  , &options_t::defines);
+  // dynamic_attr() is what allows register_fn/autotune to attach integer
+  // defines directly as attributes (opt.VAR instead of opt.d('VAR'))
+  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
+    .def_readwrite("num_warps", &rt::options_t::num_warps)
+    .def_readwrite("defines"  , &rt::options_t::defines);
 
-  pybind11::class_<options_space_t>(m, "options_space")
+  pybind11::class_<rt::options_space_t>(m, "options_space")
     .def(pybind11::init<>())
-    .def_readwrite("defines", &options_space_t::defines)
-    .def_readwrite("num_warps", &options_space_t::num_warps);
+    .def_readwrite("num_warps", &rt::options_space_t::num_warps)
+    .def_readwrite("defines"  , &rt::options_space_t::defines);
 
   // hooks into triton constructs since frameworks may not use pybind11
   m.def("get_fn_signature", &get_fn_signature);
-  m.def("get_fn_asm", &get_fn_asm);
+  // m.def("get_fn_asm", &get_fn_asm);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
   m.def("delete_fn", &delete_fn);
   m.def("make_op_id", &make_op_id);
   m.def("cleanup", &cleanup);
-  ;
+  m.def("autotune", &autotune, pybind11::return_value_policy::reference);
+  m.def("launch_kernel", &launch_kernel);
+
+  init_launch(m);
+  init_superblocking(m);
 }
@@ -1,95 +0,0 @@
-// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
-// as a string constructed with struct.pack in python
-
-#include "triton/driver/buffer.h"
-#include "triton/driver/stream.h"
-#include "triton/runtime/function.h"
-#include "triton/tools/bench.hpp"
-#include "torch/script.h"
-#include "ATen/cuda/CUDAContext.h"
-#include <c10/cuda/CUDAException.h>
-#include <cuda_runtime_api.h>
-
-namespace rt = triton::runtime;
-namespace drv = triton::driver;
-
-typedef std::pair<int, int> map_key_t;
-extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
-std::shared_ptr<drv::device> host_device;
-std::shared_ptr<drv::context> host_context;
-std::shared_ptr<drv::stream> host_stream;
-
-int64_t cdiv_sum(torch::Tensor x, int64_t div){
-  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
-  auto _x = x.accessor<int, 1>();
-  int64_t ret = 0;
-  for(size_t i = 0; i < x.size(0); i++)
-    ret += (_x[i] + div - 1) / div;
-  return ret;
-}
-
-void init_host_stream() {
-  if(!host_stream){
-    host_device.reset(new drv::host_device());
-    host_context.reset(drv::context::create(&*host_device));
-    host_stream.reset(drv::stream::create(host_context->backend()));
-  }
-}
-
-CUstream torch_get_cuda_stream(int64_t dev_id) {
-  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
-}
-
-CUdevice torch_get_cuda_device(int64_t dev_id) {
-  CUdevice ret;
-  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
-  return ret;
-}
-
-void synchronize(int64_t dev_id) {
-  if(dev_id == -1){
-    init_host_stream();
-    host_stream->synchronize();
-  }
-  else{
-    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
-    stream.synchronize();
-  }
-}
-
-torch::Tensor cuda_empty_like(torch::Tensor x){
-  if(x.nbytes() == 0)
-    return torch::empty_like(x);
-  void* data;
-  cudaMalloc(&data, x.nbytes());
-  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
-  return ret;
-}
-
-void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
-                   const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
-  rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
-  for(size_t n = 0; n < constant_names.size(); n++){
-    const torch::Tensor& x = constant_vals[n];
-    fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
-  }
-  if(dev_id == -1){
-    init_host_stream();
-    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
-  }
-  else{
-    C10_CUDA_CHECK(cudaSetDevice(dev_id));
-    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
-    triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
-    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
-  }
-}
-
-static auto registry = torch::RegisterOperators()
-  .op("triton::launch_kernel", &launch_kernel)
-  .op("triton::cuda_empty_like", &cuda_empty_like)
-  .op("triton::cdiv_sum", &cdiv_sum)
-  .op("triton::synchronize", &synchronize);
python/src/torch/launch.cc (new file, 83 lines)
@@ -0,0 +1,83 @@
+// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
+// as a string constructed with struct.pack in python
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <pybind11/functional.h>
+#include "triton/driver/buffer.h"
+#include "triton/driver/stream.h"
+#include "triton/runtime/function.h"
+#include "triton/tools/bench.hpp"
+#include "torch/script.h"
+#include "ATen/cuda/CUDAContext.h"
+#include <c10/cuda/CUDAException.h>
+#include <cuda_runtime_api.h>
+
+namespace rt = triton::runtime;
+namespace drv = triton::driver;
+
+typedef std::pair<int, int> map_key_t;
+extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
+extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
+
+int64_t cdiv(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+int64_t largest_pow2_divisor(int64_t a){
+  if(a % 8 == 0) return 8;
+  if(a % 4 == 0) return 4;
+  if(a % 2 == 0) return 2;
+  return 1;
+}
+
+int64_t cdiv_sum(torch::Tensor x, int64_t div){
+  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
+  auto _x = x.accessor<int, 1>();
+  int64_t ret = 0;
+  for(size_t i = 0; i < x.size(0); i++)
+    ret += (_x[i] + div - 1) / div;
+  return ret;
+}
+
+CUstream torch_get_cuda_stream(int64_t dev_id) {
+  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
+}
+
+CUdevice torch_get_cuda_device(int64_t dev_id) {
+  CUdevice ret;
+  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
+  return ret;
+}
+
+void synchronize(int64_t dev_id) {
+  tt_streams[dev_id]->synchronize();
+}
+
+torch::Tensor cuda_empty_like(torch::Tensor x){
+  if(x.nbytes() == 0)
+    return torch::empty_like(x);
+  void* data;
+  cudaMalloc(&data, x.nbytes());
+  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
+  return ret;
+}
+
+void cuda_set_device(int64_t dev_id) {
+  if(dev_id >= 0)
+    C10_CUDA_CHECK(cudaSetDevice(dev_id));
+}
+
+void init_launch(pybind11::module &m) {
+  m.def("cuda_set_device", &cuda_set_device);
+  m.def("cuda_empty_like", &cuda_empty_like);
+  m.def("largest_pow2_divisor", &largest_pow2_divisor);
+  m.def("cdiv", &cdiv);
+  m.def("cdiv_sum", &cdiv_sum);
+  m.def("synchronize", &synchronize);
+}
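Assuming the extension builds as `libtriton`, the helpers bound above behave as follows; this is an illustrative check, not a test shipped with this commit:

import torch
import libtriton

assert libtriton.cdiv(10, 3) == 4               # ceil(10 / 3)
assert libtriton.largest_pow2_divisor(24) == 8  # power-of-2 divisors are capped at 8
x = torch.tensor([10, 20, 30], dtype=torch.int32)  # cdiv_sum requires a CPU int tensor
assert libtriton.cdiv_sum(x, 8) == 2 + 3 + 4    # sum of per-element ceiling divisions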
python/src/torch/superblock.cc (new file, 117 lines)
@@ -0,0 +1,117 @@
+#include <torch/extension.h>
+#include <string>
+#include <tuple>
+#include <vector>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
+
+void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
+  size_t H = layout.size(0);
+  size_t M = layout.size(1);
+  size_t N = layout.size(2);
+  torch::Tensor tmp = torch::zeros_like(layout);
+  auto _tmp     = tmp.accessor    <int, 3>();
+  auto _layout  = layout.accessor <int, 3>();
+  auto _idx     = idx.accessor    <int, 3>();
+  auto _scratch = scratch.accessor<int, 3>();
+  std::vector<int> current(H, 0);
+#ifdef _OPENMP
+#pragma omp parallel for
+#endif
+  for(size_t h = 0; h < H; h++){
+    // surrounding indices
+    std::vector<int> ii_left(max_width, -1);
+    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
+
+    for(size_t m = 0; m < M; m++){
+    for(size_t n = 0; n < N; n++){
+      int v = _layout[h][m][n];
+      if(v == 0)
+        continue;
+      int n_left = ii_left[max_width-1];
+      int m_top  = ii_top [max_width-1][n];
+      int top     = (m_top >= 0)                ? _tmp[h][m_top][n]      : 0;
+      int left    = (n_left >= 0)               ? _tmp[h][m][n_left]     : 0;
+      int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
+      // maximal-square dynamic programming: width of the largest square of
+      // nonzero blocks whose bottom-right corner is (m, n)
+      int width = std::min(left, std::min(top, topleft)) + 1;
+
+      // reset width if blocks cannot be
+      // packed together (i.e., there's a 1 "in the middle")
+      for(int nn = n_left + 1; nn < n; nn++)
+        if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
+          width = 1;
+      _tmp[h][m][n] = width;
+
+      // update ii_left ring buffer
+      for(int k = 0; k < max_width-1; k++)
+        ii_left[k] = ii_left[k+1];
+      ii_left[max_width-1] = n;
+
+      // update ii_top ring buffer
+      for(int k = 0; k < max_width-1; k++)
+        ii_top[k][n] = ii_top[k+1][n];
+      ii_top[max_width-1][n] = m;
+
+      // block is too small -- skip
+      if(width != max_width)
+        continue;
+
+      // retained blocks are set to zero
+      for(size_t km = 0; km < max_width; km++)
+      for(size_t kn = 0; kn < max_width; kn++){
+        int mm = ii_top[km][n];
+        int nn = ii_left[kn];
+        if(mm < 0 || nn < 0)
+          continue;
+        _layout[h][mm][nn] = 0;
+        _tmp[h][mm][nn] = 0;
+        _scratch[h][current[h]][0] = (int)h;
+        _scratch[h][current[h]][1] = (int)mm;
+        _scratch[h][current[h]][2] = (int)nn;
+        _scratch[h][current[h]][3] = _idx[h][mm][nn];
+        current[h]++;
+      }
+    }
+    }
+  }
+  std::vector<torch::Tensor> to_cat;
+  for(size_t h = 0; h < H; h++)
+    if(current[h] > 0)
+      to_cat.push_back(scratch[h].slice(0, 0, current[h]));
+  if(!to_cat.empty())
+    ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
+}
+
+ret_t superblock(torch::Tensor layout, int start_width) {
+  ret_t ret;
+  // block index
+  torch::Tensor idx = torch::zeros_like(layout);
+  int current = 0;
+  int64_t H = layout.size(0);
+  int64_t M = layout.size(1);
+  int64_t N = layout.size(2);
+  auto _layout = layout.accessor<int, 3>();
+  auto _idx    = idx.accessor<int, 3>();
+  for(int64_t h = 0; h < H; h++)
+  for(int64_t m = 0; m < M; m++)
+  for(int64_t n = 0; n < N; n++){
+    if(_layout[h][m][n] == 0)
+      continue;
+    _idx[h][m][n] = current++;
+  }
+  // scratch memory
+  torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
+  for(int max_width = start_width; max_width > 0; max_width /= 2)
+    segment_blocks(layout, idx, scratch, max_width, ret);
+  return ret;
+}
+
+void init_superblocking(pybind11::module &m) {
+  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
+}
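A hedged usage sketch for the binding above: `layout` is an int32 0/1 tensor of shape (H, M, N) marking nonzero blocks, and each returned tuple pairs a super-block width with a tensor of (h, m, n, idx) rows; note that segment_blocks zeroes retained entries of `layout` in place:

import torch
import libtriton

# one head, a 4x4 grid of blocks, with a dense 2x2 patch in the top-left corner
layout = torch.zeros(1, 4, 4, dtype=torch.int32)
layout[0, :2, :2] = 1
for width, blocks in libtriton.superblock(layout, 2):  # start_width = 2
    # the four constituent blocks are grouped at width 2,
    # so this should print: 2 torch.Size([4, 4])
    print(width, blocks.shape)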