[BUILD] Added automatic nightly build releases to pip in CI; removed build-time dependence on LLVM and PyTorch (#77)
Recently there has been more and more report about installation issues: - Installing Triton before upgrading pytorch can create some issues because Triton uses some torch headers - llvm-10-dev not available on some platform; llvm-11-dev not available on e.g. Ubuntu. absence of nightly builds This PR should fix all these issues. Some CMake tricks are used to download and install llvm at build time. Triton Python bindings were modified to remove dependence on pytorch ops. Midnight CI job added to generate binary wheels for all Triton version and update them on pypi's new triton-nightly project. This PR will also make it very easy to use LLVM forks in the future for whatever needs we have.
This commit is contained in:
committed by
Philippe Tillet
parent
3ad0a4d7be
commit
2f80a98776
@@ -4,8 +4,6 @@
|
||||
#include "cutlass/library/singleton.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
using namespace cutlass;
|
||||
using namespace cutlass::library;
|
||||
@@ -132,58 +130,56 @@ const Operation *autotune(int M, int N, int K,
|
||||
}
|
||||
|
||||
// map of torch datatypes to cutlass datatypes
|
||||
std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
|
||||
{caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
|
||||
{caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
|
||||
{caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};
|
||||
std::map<std::string, NumericTypeID> type_map = {
|
||||
{"float16", NumericTypeID::kF16},
|
||||
{"float32", NumericTypeID::kF32},
|
||||
{"float64", NumericTypeID::kF64}};
|
||||
|
||||
void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
|
||||
size_t M = A.size(0);
|
||||
size_t N = B.size(1);
|
||||
size_t K = A.size(1);
|
||||
size_t lda = A.stride(0);
|
||||
size_t ldb = B.stride(0);
|
||||
size_t ldc = C.stride(1);
|
||||
size_t ldd = C.stride(1);
|
||||
void *ptr_A = A.data_ptr();
|
||||
void *ptr_B = B.data_ptr();
|
||||
void *ptr_C = C.data_ptr();
|
||||
void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
|
||||
size_t M, size_t N, size_t K,
|
||||
size_t stride_a_0, size_t stride_a_1,
|
||||
size_t stride_b_0, size_t stride_b_1,
|
||||
size_t stride_c_0, size_t stride_c_1,
|
||||
std::string type_a, std::string type_b, std::string type_c,
|
||||
size_t dev_id, uint64_t stream_handle) {
|
||||
void *ptr_A = (void *)A;
|
||||
void *ptr_B = (void *)B;
|
||||
void *ptr_C = (void *)C;
|
||||
void *ptr_D = ptr_C;
|
||||
size_t lda = stride_a_0;
|
||||
size_t ldb = stride_b_0;
|
||||
size_t ldc = stride_c_1;
|
||||
size_t ldd = ldc;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
// layout for A
|
||||
LayoutTypeID layout_A;
|
||||
if (A.stride(0) == 1)
|
||||
if (stride_a_0 == 1)
|
||||
layout_A = LayoutTypeID::kColumnMajor;
|
||||
else if (A.stride(1) == 1)
|
||||
else if (stride_a_1 == 1)
|
||||
layout_A = LayoutTypeID::kRowMajor;
|
||||
else {
|
||||
A = A.contiguous();
|
||||
layout_A = LayoutTypeID::kRowMajor;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("A layout is not supported");
|
||||
// layout for B
|
||||
LayoutTypeID layout_B;
|
||||
if (B.stride(0) == 1)
|
||||
if (stride_b_0 == 1)
|
||||
layout_B = LayoutTypeID::kColumnMajor;
|
||||
else if (B.stride(1) == 1)
|
||||
else if (stride_b_1 == 1)
|
||||
layout_B = LayoutTypeID::kRowMajor;
|
||||
else {
|
||||
B = B.contiguous();
|
||||
layout_B = LayoutTypeID::kRowMajor;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("B layout is not supported");
|
||||
// data types
|
||||
NumericTypeID element_compute = NumericTypeID::kF32;
|
||||
NumericTypeID element_A = type_map[A.dtype().id()];
|
||||
NumericTypeID element_B = type_map[B.dtype().id()];
|
||||
NumericTypeID element_C = type_map[C.dtype().id()];
|
||||
NumericTypeID element_A = type_map[type_a];
|
||||
NumericTypeID element_B = type_map[type_b];
|
||||
NumericTypeID element_C = type_map[type_c];
|
||||
// misc. flags
|
||||
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32;
|
||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||
// runtime flags
|
||||
size_t dev_id = C.device().index();
|
||||
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
cudaStream_t stream = (cudaStream_t)stream_handle;
|
||||
// auto-tune
|
||||
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
|
||||
dev_id, (size_t)element_compute, (size_t)scalar_mode};
|
||||
|
Reference in New Issue
Block a user