[BUILD] Added automatic nightly build releases to pip in CI; removed build-time dependence on LLVM and PyTorch (#77)
Recently there have been more and more reports about installation issues:
- Installing Triton before upgrading PyTorch can create issues, because Triton uses some torch headers.
- llvm-10-dev is not available on some platforms, and llvm-11-dev is not available on, e.g., Ubuntu.
- There are no nightly builds.

This PR should fix all these issues. Some CMake tricks are used to download and install LLVM at build time. The Triton Python bindings were modified to remove their dependence on PyTorch ops (see the sketch after the commit metadata below). A midnight CI job was added to generate binary wheels for all Triton versions and upload them to PyPI's new triton-nightly project. This PR will also make it very easy to use LLVM forks in the future, for whatever needs we may have.
committed by Philippe Tillet
parent 3ad0a4d7be
commit 2f80a98776
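The key design change in the bindings is visible throughout the diff below: instead of accepting torch::Tensor arguments, the C++ entry points now take plain integers (pointers, device ordinals, stream handles) and dtype strings, so the extension no longer compiles against torch headers. A minimal sketch of what this boundary looks like from the Python side (the helper below is illustrative, not part of this PR):

```python
import torch

def tensor_to_handles(t: torch.Tensor):
    # Decompose a tensor into the plain values the new bindings accept:
    # a raw device pointer, the strides, and a dtype string such as
    # "float32" that matches the keys of the C++ type_map.
    return (t.data_ptr(), tuple(t.stride()), str(t.dtype).split(".")[-1])
```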
@@ -1 +0,0 @@
-../../CMakeLists.txt
@@ -1 +0,0 @@
-../../cmake/
@@ -4,8 +4,6 @@
 #include "cutlass/library/singleton.h"
 #include "pybind11/pybind11.h"
 #include "triton/tools/bench.hpp"
-#include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
 
 using namespace cutlass;
 using namespace cutlass::library;
@@ -132,58 +130,56 @@ const Operation *autotune(int M, int N, int K,
 }
 
 // map of torch datatypes to cutlass datatypes
-std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
-    {caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
-    {caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
-    {caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};
+std::map<std::string, NumericTypeID> type_map = {
+    {"float16", NumericTypeID::kF16},
+    {"float32", NumericTypeID::kF32},
+    {"float64", NumericTypeID::kF64}};
 
-void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
-  size_t M = A.size(0);
-  size_t N = B.size(1);
-  size_t K = A.size(1);
-  size_t lda = A.stride(0);
-  size_t ldb = B.stride(0);
-  size_t ldc = C.stride(1);
-  size_t ldd = C.stride(1);
-  void *ptr_A = A.data_ptr();
-  void *ptr_B = B.data_ptr();
-  void *ptr_C = C.data_ptr();
+void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
+                    size_t M, size_t N, size_t K,
+                    size_t stride_a_0, size_t stride_a_1,
+                    size_t stride_b_0, size_t stride_b_1,
+                    size_t stride_c_0, size_t stride_c_1,
+                    std::string type_a, std::string type_b, std::string type_c,
+                    size_t dev_id, uint64_t stream_handle) {
+  void *ptr_A = (void *)A;
+  void *ptr_B = (void *)B;
+  void *ptr_C = (void *)C;
   void *ptr_D = ptr_C;
+  size_t lda = stride_a_0;
+  size_t ldb = stride_b_0;
+  size_t ldc = stride_c_1;
+  size_t ldd = ldc;
   float alpha = 1.0f;
   float beta = 0.0f;
   // layout for A
   LayoutTypeID layout_A;
-  if (A.stride(0) == 1)
+  if (stride_a_0 == 1)
     layout_A = LayoutTypeID::kColumnMajor;
-  else if (A.stride(1) == 1)
+  else if (stride_a_1 == 1)
     layout_A = LayoutTypeID::kRowMajor;
-  else {
-    A = A.contiguous();
-    layout_A = LayoutTypeID::kRowMajor;
-  }
+  else
+    throw std::runtime_error("A layout is not supported");
   // layout for B
   LayoutTypeID layout_B;
-  if (B.stride(0) == 1)
+  if (stride_b_0 == 1)
    layout_B = LayoutTypeID::kColumnMajor;
-  else if (B.stride(1) == 1)
+  else if (stride_b_1 == 1)
     layout_B = LayoutTypeID::kRowMajor;
-  else {
-    B = B.contiguous();
-    layout_B = LayoutTypeID::kRowMajor;
-  }
+  else
+    throw std::runtime_error("B layout is not supported");
   // data types
   NumericTypeID element_compute = NumericTypeID::kF32;
-  NumericTypeID element_A = type_map[A.dtype().id()];
-  NumericTypeID element_B = type_map[B.dtype().id()];
-  NumericTypeID element_C = type_map[C.dtype().id()];
+  NumericTypeID element_A = type_map[type_a];
+  NumericTypeID element_B = type_map[type_b];
+  NumericTypeID element_C = type_map[type_c];
   // misc. flags
   ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
   NumericTypeID element_scalar = NumericTypeID::kF32;
   ComplexTransform transform_A = ComplexTransform::kNone;
   ComplexTransform transform_B = ComplexTransform::kNone;
   // runtime flags
-  size_t dev_id = C.device().index();
-  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
+  cudaStream_t stream = (cudaStream_t)stream_handle;
   // auto-tune
   std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
                                   dev_id, (size_t)element_compute, (size_t)scalar_mode};
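For illustration, a hedged sketch of how the new pointer-based cutlass_matmul signature could be driven from Python; the libtriton module path and the exposed function name are assumptions, but the argument list mirrors the binding above:

```python
import torch
import libtriton  # assumed: the pybind11 module built from this file

def cutlass_matmul(a, b, c):
    # hypothetical wrapper: forwards raw pointers and metadata so the
    # C++ side never needs to see a torch::Tensor
    M, K = a.shape
    _, N = b.shape
    dev_id = c.device.index
    dt = lambda t: str(t.dtype).split(".")[-1]  # "torch.float16" -> "float16"
    libtriton.cutlass_matmul(
        a.data_ptr(), b.data_ptr(), c.data_ptr(),
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        dt(a), dt(b), dt(c),
        dev_id,
        torch.cuda.current_stream(dev_id).cuda_stream,
    )
```

Note that the strides passed in must already describe a row- or column-major layout: unlike the old torch version, the new code throws instead of silently calling .contiguous().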
@@ -1 +0,0 @@
-../../include/
@@ -1 +0,0 @@
-../../lib/
@@ -8,7 +8,6 @@ void init_cutlass(pybind11::module &m);
 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
   init_triton(m);
-  init_torch_utils(m);
   init_superblocking(m);
 #ifdef WITH_CUTLASS_BINDINGS
   init_cutlass(m);
python/src/superblock.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
+#include <iostream>
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <string>
+#include <tuple>
+#include <vector>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+// row-major 3d tensor
+class tensor_3d {
+public:
+  tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
+    if (data)
+      std::copy(data, data + data_.size(), data_.begin());
+    stride_0_ = size_1 * size_2;
+    stride_1_ = size_2;
+    stride_2_ = 1;
+  }
+
+  int &operator()(int i, int j, int k) {
+    return data_[i * stride_0_ + j * stride_1_ + k];
+  }
+
+private:
+  std::vector<int> data_;
+  int stride_0_;
+  int stride_1_;
+  int stride_2_;
+};
+
+std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
+  tensor_3d tmp(H, M, N);
+  std::vector<int> current(H, 0);
+  int num = 0;
+  std::vector<int> lut(H * M * N * 4);
+  for (size_t h = 0; h < H; h++) {
+    // surrounding indices
+    std::vector<int> ii_left(max_width, -1);
+    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
+    // start the dynamic programming algorithm
+    for (size_t m = 0; m < M; m++) {
+      for (size_t n = 0; n < N; n++) {
+        int v = layout(h, m, n);
+        if (v == 0)
+          continue;
+        int n_left = ii_left[max_width - 1];
+        int m_top = ii_top[max_width - 1][n];
+        int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
+        int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
+        int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
+        int width = std::min(left, std::min(top, topleft)) + 1;
+        // reset width if blocks cannot be
+        // packed together (i.e., there's a 1 "in the middle")
+        for (int nn = n_left + 1; nn < n; nn++)
+          if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
+            width = 1;
+        tmp(h, m, n) = width;
+        // update n_left ring buffer
+        for (int k = 0; k < max_width - 1; k++)
+          ii_left[k] = ii_left[k + 1];
+        ii_left[max_width - 1] = n;
+        // update ii_top ring buffer
+        for (int k = 0; k < max_width - 1; k++)
+          ii_top[k][n] = ii_top[k + 1][n];
+        ii_top[max_width - 1][n] = m;
+        // block is too small -- skip
+        if (width != max_width)
+          continue;
+        // retained blocks are set to zeros
+        for (size_t km = 0; km < max_width; km++)
+          for (size_t kn = 0; kn < max_width; kn++) {
+            int mm = ii_top[km][n];
+            int nn = ii_left[kn];
+            if (mm < 0 || nn < 0)
+              continue;
+            layout(h, mm, nn) = 0;
+            tmp(h, mm, nn) = 0;
+            lut[num++] = (int)h;
+            lut[num++] = (int)mm;
+            lut[num++] = (int)nn;
+            lut[num++] = idx(h, mm, nn);
+          }
+      }
+    }
+  }
+  lut.resize(num);
+  return lut;
+}
+
+typedef std::pair<int, pybind11::array_t<int>> lut_t;
+
+std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
+  std::vector<lut_t> ret;
+  int current = 0;
+  tensor_3d layout(H, M, N, (int *)LAYOUT);
+  tensor_3d idx(H, M, N);
+  for (int64_t h = 0; h < H; h++)
+    for (int64_t m = 0; m < M; m++)
+      for (int64_t n = 0; n < N; n++) {
+        if (layout(h, m, n) == 0)
+          continue;
+        idx(h, m, n) = current++;
+      }
+  // create lut
+  for (int max_width = start_width; max_width > 0; max_width /= 2) {
+    auto lut = segment_blocks(layout, idx, max_width, H, M, N);
+    if (lut.size() == 0)
+      continue;
+    ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
+  }
+  return ret;
+}
+
+void init_superblocking(pybind11::module &m) {
+  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
+}
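Because the new superblock entry point takes a raw uintptr_t instead of a torch::Tensor, it can be fed from any contiguous int32 buffer. A hedged usage sketch with NumPy (the module name is an assumption; "superblock" is the name bound just above):

```python
import numpy as np
import libtriton  # assumed module name

H, M, N = 2, 8, 8
# 0/1 layout of retained blocks; must be contiguous int32, and it is
# modified in place as blocks get packed into superblocks
layout = (np.random.rand(H, M, N) < 0.5).astype(np.int32)
luts = libtriton.superblock(layout.ctypes.data, H, M, N, 4)
for width, lut in luts:
    # each lut is a flat int32 array of (h, m, n, block_idx) records
    print(width, lut.reshape(-1, 4).shape)
```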
@@ -1,117 +0,0 @@
-#include <torch/extension.h>
-#include <string>
-#include <tuple>
-#include <vector>
-#ifdef _OPENMP
-#include <omp.h>
-#endif
-
-typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
-
-void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
-  size_t H = layout.size(0);
-  size_t M = layout.size(1);
-  size_t N = layout.size(2);
-  torch::Tensor tmp = torch::zeros_like(layout);
-  auto _tmp = tmp.accessor <int, 3>();
-  auto _layout = layout.accessor <int, 3>();
-  auto _idx = idx.accessor <int, 3>();
-  auto _scratch = scratch.accessor<int, 3>();
-  std::vector<int> current(H, 0);
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-  for(size_t h = 0; h < H; h++){
-    // surrounding indices
-    std::vector<int> ii_left(max_width, -1);
-    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
-
-    for(size_t m = 0; m < M; m++){
-      for(size_t n = 0; n < N; n++){
-        int v = _layout[h][m][n];
-        if(v == 0)
-          continue;
-        int n_left= ii_left[max_width-1];
-        int m_top = ii_top [max_width-1][n];
-        int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
-        int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
-        int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
-        int width = std::min(left, std::min(top, topleft)) + 1;
-
-        // reset width if blocks cannot be
-        // packed together (i.e., there's a 1 "in the middle")
-        for(int nn = n_left + 1; nn < n; nn++)
-          if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
-            width = 1;
-        _tmp[h][m][n] = width;
-
-        // update n_left ring buffer
-        for(int k = 0; k < max_width-1; k++)
-          ii_left[k] = ii_left[k+1];
-        ii_left[max_width-1] = n;
-
-        // update ii_top ring buffer
-        for(int k = 0; k < max_width-1; k++)
-          ii_top[k][n] = ii_top[k+1][n];
-        ii_top[max_width-1][n] = m;
-
-        // block is too small -- skip
-        if(width != max_width)
-          continue;
-
-        // retained blocks are set to zeros
-        for(size_t km = 0; km < max_width; km++)
-          for(size_t kn = 0; kn < max_width; kn++)
-          {
-            int mm = ii_top[km][n];
-            int nn = ii_left[kn];
-            if(mm < 0 || nn < 0)
-              continue;
-            _layout[h][mm][nn] = 0;
-            _tmp[h][mm][nn] = 0;
-            _scratch[h][current[h]][0] = (int)h;
-            _scratch[h][current[h]][1] = (int)mm;
-            _scratch[h][current[h]][2] = (int)nn;
-            _scratch[h][current[h]][3] = _idx[h][mm][nn];
-            current[h]++;
-          }
-      }
-    }
-  }
-  std::vector<torch::Tensor> to_cat;
-  for(size_t h = 0; h < H; h++)
-    if(current[h] > 0)
-      to_cat.push_back(scratch[h].slice(0, 0, current[h]));
-  if(!to_cat.empty())
-    ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
-}
-
-
-ret_t superblock(torch::Tensor layout, int start_width) {
-  ret_t ret;
-  // block index
-  torch::Tensor idx = torch::zeros_like(layout);
-  int current = 0;
-  int64_t H = layout.size(0);
-  int64_t M = layout.size(1);
-  int64_t N = layout.size(2);
-  auto _layout = layout.accessor <int, 3>();
-  auto _idx = idx.accessor<int, 3>();
-  for(int64_t h = 0; h < H; h++)
-    for(int64_t m = 0; m < M; m++)
-      for(int64_t n = 0; n < N; n++){
-        if(_layout[h][m][n] == 0)
-          continue;
-        _idx[h][m][n] = current++;
-      }
-  // scratch memory
-  torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
-  for(int max_width = start_width; max_width > 0; max_width /= 2)
-    segment_blocks(layout, idx, scratch, max_width, ret);
-  return ret;
-}
-
-
-void init_superblocking(pybind11::module &m) {
-  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
-}
@@ -1,32 +0,0 @@
-
-#include "triton/driver/device.h"
-#include "triton/driver/stream.h"
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime_api.h>
-#include <torch/extension.h>
-
-namespace torch_utils {
-
-uint64_t cu_device(int64_t dev_id) {
-  CUdevice handle;
-  triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
-  return (uint64_t)handle;
-}
-
-uint64_t cu_stream(int64_t dev_id) {
-  return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
-}
-
-void set_device(int64_t dev_id) {
-  if (dev_id >= 0)
-    C10_CUDA_CHECK(cudaSetDevice(dev_id));
-}
-
-} // namespace torch_utils
-
-void init_torch_utils(pybind11::module &m) {
-  pybind11::module subm = m.def_submodule("torch_utils");
-  subm.def("cu_device", &torch_utils::cu_device);
-  subm.def("cu_stream", &torch_utils::cu_stream);
-  subm.def("set_device", &torch_utils::set_device);
-}
@@ -89,7 +89,11 @@ void init_triton_driver(py::module &&m) {
   py::class_<drv::device>(m, "device");
   // cuda device
   py::class_<drv::cu_device, driver::device>(m, "cu_device")
-      .def(py::init<CUdevice, bool>());
+      .def(py::init([](int dev_id, bool take_ownership) {
+        CUdevice handle;
+        drv::dispatch::cuDeviceGet(&handle, dev_id);
+        return new drv::cu_device(handle, take_ownership);
+      }));
   // host device
   py::class_<drv::host_device, driver::device>(m, "host_device")
       .def(py::init<>());
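With torch_utils gone, the handles it used to produce now come from Python directly. A hedged sketch (the module and submodule names are guesses based on init_triton_driver above) of constructing a cu_device through the new lambda initializer and obtaining the stream handle from torch:

```python
import torch
import libtriton  # assumed module name; the driver submodule layout is a guess

dev_id = 0
# the new initializer resolves the CUdevice handle internally from the
# device ordinal, replacing the removed torch_utils::cu_device helper
device = libtriton.driver.cu_device(dev_id, False)
# the stream handle comes straight from torch instead of torch_utils::cu_stream
stream_handle = torch.cuda.current_stream(dev_id).cuda_stream
```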