From 44ca2c0cb878e827a4822c461f5efaa5536b50b0 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 26 Nov 2020 23:21:14 -0500
Subject: [PATCH] [DRIVER] Removed deprecated files and functions

---
 include/triton/driver/buffer.h |   1 -
 include/triton/driver/cublas.h | 229 ---------------------------------
 include/triton/driver/event.h  |  29 -----
 include/triton/driver/stream.h |  12 +-
 lib/driver/event.cc            |  40 ------
 lib/driver/stream.cc           |   9 +-
 lib/runtime/function.cc        |   2 +-
 7 files changed, 6 insertions(+), 316 deletions(-)
 delete mode 100755 include/triton/driver/cublas.h
 delete mode 100755 include/triton/driver/event.h
 delete mode 100755 lib/driver/event.cc

diff --git a/include/triton/driver/buffer.h b/include/triton/driver/buffer.h
index a8d588640..b3e9d89be 100755
--- a/include/triton/driver/buffer.h
+++ b/include/triton/driver/buffer.h
@@ -20,7 +20,6 @@ public:
   buffer(size_t size, host_buffer_t hst, bool take_ownership);
   uintptr_t addr_as_uintptr_t();
   static buffer* create(driver::context* ctx, size_t size);
-  driver::context* context();
   size_t size();

 protected:
diff --git a/include/triton/driver/cublas.h b/include/triton/driver/cublas.h
deleted file mode 100755
index 2553dcb89..000000000
--- a/include/triton/driver/cublas.h
+++ /dev/null
@@ -1,229 +0,0 @@
-/* Copyright 2015-2017 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#ifndef TDL_INCLUDE_DRIVER_CUBLAS_H
-#define TDL_INCLUDE_DRIVER_CUBLAS_H
-
-#include "isaac/templates/common.hpp"
-#include "triton/driver/dispatch.h"
-#include "triton/driver/buffer.h"
-#include "triton/driver/stream.h"
-#include "triton/driver/backend.h"
-#include "triton/driver/error.h"
-#include "triton/tools/bench.hpp"
-#include "triton/tools/collections.hpp"
-
-namespace triton
-{
-namespace driver
-{
-
-enum cublasStrategy_t{
-  CUBLAS_PREFER_FASTEST,
-  CUBLAS_HEURISTICS
-};
-
-
-static const std::vector<cublasGemmAlgo_t> cublasAlgorithms = {
-  CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3,
-  CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7
-};
-
-static const std::map<DType, cudaDataType> cudtype = {{FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE,CUDA_R_64F}};
-static const std::map<char, cublasOperation_t> cuop = {{'N', CUBLAS_OP_N}, {'T', CUBLAS_OP_T}};
-
-inline cublasGemmAlgo_t cublasGemmFastest(stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
-                                          void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
-                                          void* beta, CUdeviceptr C, int32_t ldc){
-
-  typedef std::tuple<cudaDataType_t, cublasOperation_t, cublasOperation_t, int32_t, int32_t, int32_t> key_t;
-  // Benchmark fastest algorithm in cublasGemmEx
-  auto benchmark_fastest = [&](key_t const &){
-    std::vector<double> times;
-    for(cublasGemmAlgo_t a: cublasAlgorithms){
-      try{
-        times.push_back(bench([&](){ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, a); },
-                              [&](){ stream.synchronize(); },
-                              stream.context().device()));
-      }catch(driver::exception::cublas::base const &){
-        times.push_back(INFINITY);
-      }
-    }
-    size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
-    return cublasAlgorithms[argmin];
-  };
-  // Cache result
-  static cpp::CachedMap<key_t, cublasGemmAlgo_t> cache(benchmark_fastest);
-  return cache.get(std::make_tuple(cudt, AT, BT, M, N, K));
-}
-
-/* Wrapper for cublasGemmEx */
-inline void cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
-                         void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
-                         void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
-{ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); }
-
-
-/* Simplified API for default GEMM */
-inline void cublasGemm(DType dtype, stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, cu_buffer const & A, int32_t lda, cu_buffer const & B, int32_t ldb, scalar beta, cu_buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
-  ContextSwitcher ctx_switch(stream.context());
-  cublasHandle_t handle = dispatch::cublasHandle(stream.context());
-  dispatch::cublasSetStream_v2(handle, (CUstream)stream);
-  if(fastest)
-    *fastest = cublasGemmFastest(stream, handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
-  else
-    cublasGemmEx(handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc, algo);
-}
-
-inline cudnnDataType_t cudnnDtype(DType dtype){
-  switch(dtype){
-    case INT8X4_TYPE: return CUDNN_DATA_INT8x4;
-    case INT32_TYPE: return CUDNN_DATA_INT32;
-    case FLOAT_TYPE: return CUDNN_DATA_FLOAT;
-    case DOUBLE_TYPE: return CUDNN_DATA_DOUBLE;
-  }
-  throw;
-}
-
-inline cudnnTensorFormat_t format(cudnnDataType_t cutype){
-  switch(cutype){
-    case CUDNN_DATA_INT8x4: return CUDNN_TENSOR_NCHW_VECT_C;
-    default: return CUDNN_TENSOR_NCHW;
-  }
-}
-
-inline void cudnnConv(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
-                      int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, cu_buffer const & F, scalar beta, cu_buffer const & O){
-  driver::driver::context const & ctx = stream.context();
-  ContextSwitcher switch_ctx(ctx);
-
-  std::vector<int> pad = {pad_d, pad_h, pad_w};
-  std::vector<int> stride = {stride_d, stride_h, stride_w};
-  std::vector<int> upscale = {1, 1, 1};
-  std::vector<int> Oshapes = {N, K, M, P, Q};
-  std::vector<int> Fshapes = {K, C, T, R, S};
-  std::vector<int> Ishapes = {N, C, D, H, W};
-  if(M == 1 && T == 1 && D == 1){
-    pad.erase(pad.begin());
-    stride.erase(stride.begin());
-    upscale.erase(upscale.begin());
-    Oshapes.erase(Oshapes.begin() + 2);
-    Ishapes.erase(Ishapes.begin() + 2);
-    Fshapes.erase(Fshapes.begin() + 2);
-  }
-
-  cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
-  cudnnDataType_t in_cutype = cudnnDtype(dtype);
-  cudnnDataType_t conv_cutype = (dtype == INT8X4_TYPE)?CUDNN_DATA_INT32:in_cutype;
-
-  dispatch::cudnnSetStream(handle, (CUstream)stream);
-  cudnnTensorDescriptor_t tO, tI;
-  cudnnFilterDescriptor_t tF;
-  cudnnConvolutionDescriptor_t conv;
-  cudnnConvolutionFwdAlgo_t algo;
-  dispatch::cudnnCreateTensorDescriptor(&tO);
-  dispatch::cudnnCreateTensorDescriptor(&tI);
-  dispatch::cudnnCreateFilterDescriptor(&tF);
-
-  dispatch::cudnnSetTensorNdDescriptorEx(tO, format(in_cutype), in_cutype, Oshapes.size(), Oshapes.data());
-  dispatch::cudnnSetFilterNdDescriptor(tF, in_cutype, format(in_cutype), Fshapes.size(), Fshapes.data());
-  dispatch::cudnnSetTensorNdDescriptorEx(tI, format(in_cutype), in_cutype, Ishapes.size(), Ishapes.data());
-
-  dispatch::cudnnCreateConvolutionDescriptor(&conv);
-  dispatch::cudnnSetConvolutionNdDescriptor(conv, pad.size(), pad.data(), stride.data(), upscale.data(), CUDNN_CROSS_CORRELATION, conv_cutype);
-  dispatch::cudnnGetConvolutionForwardAlgorithm(handle, tI, tF, conv, tO, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 1024*1024*64, &algo);
-
-  size_t workspace_size;
-  dispatch::cudnnGetConvolutionForwardWorkspaceSize(handle, tI, tF, conv, tO, algo, &workspace_size);
-  static cu_buffer work(ctx, 1024*1024*64);
-  CUdeviceptr twork = work;
-  CUdeviceptr pI = I, pF = F, pO = O;
-  dispatch::cudnnConvolutionForward(handle, alpha.data(), tI, (void*)pI, tF, (void*)pF, conv, algo, (void*)twork, workspace_size, beta.data(), tO, (void*)pO);
-}
-
-
-inline void cudnnPool(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
-                      int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, scalar beta, cu_buffer const & O){
-  driver::driver::context const & ctx = stream.context();
-  ContextSwitcher switch_ctx(ctx);
-
-  std::vector<int> pad = {pad_d, pad_h, pad_w};
-  std::vector<int> stride = {stride_d, stride_h, stride_w};
-  std::vector<int> upscale = {1, 1, 1};
-  std::vector<int> Oshapes = {N, K, M, P, Q};
-  std::vector<int> Ishapes = {N, K, D, H, W};
-  std::vector<int> window = {T, R, S};
-  if(M == 1 && T == 1 && D == 1){
-    window.erase(window.begin());
-    pad.erase(pad.begin());
-    stride.erase(stride.begin());
-    upscale.erase(upscale.begin());
-    Oshapes.erase(Oshapes.begin() + 2);
-    Ishapes.erase(Ishapes.begin() + 2);
-  }
-
-  cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
-  cudnnDataType_t cutype = cudnnDtype(dtype);
-
-  dispatch::cudnnSetStream(handle, (CUstream)stream);
-  cudnnTensorDescriptor_t tO, tI;
-  cudnnPoolingDescriptor_t desc;
-  dispatch::cudnnCreateTensorDescriptor(&tO);
-  dispatch::cudnnCreateTensorDescriptor(&tI);
-
-  dispatch::cudnnSetTensorNdDescriptorEx(tO, CUDNN_TENSOR_NCHW, cutype, Oshapes.size(), Oshapes.data());
-  dispatch::cudnnSetTensorNdDescriptorEx(tI, CUDNN_TENSOR_NCHW, cutype, Ishapes.size(), Ishapes.data());
-
-  dispatch::cudnnCreatePoolingDescriptor(&desc);
-  dispatch::cudnnSetPoolingNdDescriptor(desc, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, window.size(), window.data(), pad.data(), stride.data());
-
-  CUdeviceptr pI = I, pO = O;
-  dispatch::cudnnPoolingForward(handle, desc, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
-}
-
-inline void cudnnTransformTensor(driver::cu_stream & stream,
-                                 DType in_dtype, DType out_dtype,
-                                 cudnnTensorFormat_t in_layout, cudnnTensorFormat_t out_layout,
-                                 int32_t N, int32_t C, int32_t D, int32_t H, int32_t W,
-                                 scalar alpha, driver::cu_buffer const & I, scalar beta, driver::cu_buffer& O)
-{
-  cudnnHandle_t handle = dispatch::cudnnHandle(stream.context());
-  dispatch::cudnnSetStream(handle, (CUstream)stream);
-
-  cudnnTensorDescriptor_t tO, tI;
-  std::vector<int> shapes = {N, C, D, H, W};
-  dispatch::cudnnCreateTensorDescriptor(&tI);
-  dispatch::cudnnSetTensorNdDescriptorEx(tI, in_layout, cudnnDtype(in_dtype), shapes.size(), shapes.data());
-  dispatch::cudnnCreateTensorDescriptor(&tO);
-  dispatch::cudnnSetTensorNdDescriptorEx(tO, out_layout, cudnnDtype(out_dtype), shapes.size(), shapes.data());
-
-  CUdeviceptr pI = I, pO = O;
-  dispatch::cudnnTransformTensor(handle, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
-}
-
-
-}
-}
-
-
-
-#endif
diff --git a/include/triton/driver/event.h b/include/triton/driver/event.h
deleted file mode 100755
index 7310d001f..000000000
--- a/include/triton/driver/event.h
+++ /dev/null
@@ -1,29 +0,0 @@
-#pragma once
-
-#ifndef _TRITON_DRIVER_EVENT_H_
-#define _TRITON_DRIVER_EVENT_H_
-
-#include "triton/driver/handle.h"
-
-namespace triton
-{
-
-namespace driver
-{
-
-// event
-class event
-{
-public:
-  float elapsed_time() const;
-  handle<cu_event_t> const & cu() const;
-
-private:
-  handle<cu_event_t> cu_;
-};
-
-}
-
-}
-
-#endif
diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h
index 0d45975ff..b29813fd4 100755
--- a/include/triton/driver/stream.h
+++ b/include/triton/driver/stream.h
@@ -29,7 +29,7 @@ public:
   static driver::stream* create(backend_t backend);
   // methods
   virtual void synchronize() = 0;
-  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0;
+  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args = NULL, size_t args_size = 0) = 0;
   virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
   virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
   // template helpers
@@ -42,12 +42,9 @@ public:
 // Host
 class host_stream: public stream {
 public:
-  // Constructors
   host_stream();
-
-  // Overridden
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -55,13 +52,10 @@ public:
 // CUDA
 class cu_stream: public stream {
 public:
-  // Constructors
   cu_stream(CUstream str, bool take_ownership);
   cu_stream();
-
-  // Overridden
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
diff --git a/lib/driver/event.cc b/lib/driver/event.cc
deleted file mode 100755
index ad341d701..000000000
--- a/lib/driver/event.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2015-2017 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#include "triton/driver/event.h"
-
-namespace triton
-{
-namespace driver
-{
-
-float event::elapsed_time() const{
-  float time;
-  dispatch::cuEventElapsedTime(&time, cu_->first, cu_->second);
-  return time;
-}
-
-handle<cu_event_t> const & event::cu() const
-{ return cu_; }
-
-}
-}
diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc
index 4fd9e7436..7b25e4c18 100755
--- a/lib/driver/stream.cc
+++ b/lib/driver/stream.cc
@@ -27,7 +27,6 @@
 #include "triton/driver/stream.h"
 #include "triton/driver/context.h"
 #include "triton/driver/device.h"
-#include "triton/driver/event.h"
 #include "triton/driver/kernel.h"
 #include "triton/driver/buffer.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
@@ -77,7 +76,7 @@ void host_stream::synchronize() {
   hst_->args.clear();
 }

-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size) {
   auto hst = kernel->module()->hst();
   hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
   char* params = new char[args_size];
@@ -114,17 +113,13 @@ void cu_stream::synchronize() {
   dispatch::cuStreamSynchronize(*cu_);
 }

-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** args, size_t args_size) {
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void** args, size_t args_size) {
   void *config[] = {
     CU_LAUNCH_PARAM_BUFFER_POINTER, args,
     CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
     CU_LAUNCH_PARAM_END
   };
-  if(event)
-    dispatch::cuEventRecord(event->cu()->first, *cu_);
   dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
-  if(event)
-    dispatch::cuEventRecord(event->cu()->second, *cu_);
 }

 void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 96a1dbd97..71a18027d 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -176,7 +176,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, args, args_size);
+  stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, args, args_size);
 }
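
Note for reviewers: the only signature callers see change is stream::enqueue, which loses its unused
dependency/event parameters; launches now go straight to cuLaunchKernel with no event recording.
A minimal sketch of how a call site adapts is below. The launch() helper, its grid value, and the
includes are illustrative assumptions; only the enqueue signature and the num_warps * 32 block
convention come from this patch.

    // Before this patch, callers passed placeholder dependency/event arguments:
    //   stream->enqueue(kernel, grid, block, /*dependencies*/ NULL, /*event*/ NULL, args, args_size);
    // After this patch, only the kernel, launch geometry, and packed arguments remain.
    #include <array>
    #include <cstddef>
    #include "triton/driver/kernel.h"
    #include "triton/driver/stream.h"

    void launch(triton::driver::stream* stream, triton::driver::kernel* kernel,
                void** args, size_t args_size, size_t num_warps) {
      std::array<size_t, 3> grid  = {128, 1, 1};             // hypothetical 1-D grid
      std::array<size_t, 3> block = {num_warps * 32, 1, 1};  // same convention as lib/runtime/function.cc
      stream->enqueue(kernel, grid, block, args, args_size);
    }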