From 6d62d88d4f3899bbf9009547a7d731eed291b244 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 26 Jul 2022 17:25:03 -0700 Subject: [PATCH] [CI] run clang-format (#24) --- .github/workflows/integration-tests.yml | 8 +- bin/FileCheck/FileCheck.cpp | 17 +- bin/triton-opt.cpp | 20 +- include/triton/Analysis/AxisInfo.h | 67 +- include/triton/Conversion/Passes.h | 8 +- .../TritonToTritonGPU/TritonToTritonGPU.h | 7 +- include/triton/Dialect/Triton/IR/Dialect.h | 9 +- include/triton/Dialect/Triton/IR/Traits.h | 14 +- .../triton/Dialect/Triton/Transforms/Passes.h | 2 +- include/triton/Dialect/TritonGPU/IR/Dialect.h | 1 - .../Transforms/TritonGPUConversion.h | 5 +- include/triton/driver/dispatch.h | 387 +- include/triton/driver/error.h | 415 +- include/triton/driver/llvm.h | 25 +- include/triton/tools/bench.hpp | 47 +- include/triton/tools/graph.h | 27 +- include/triton/tools/sha1.hpp | 248 +- include/triton/tools/sys/exec.hpp | 18 +- include/triton/tools/sys/getenv.hpp | 27 +- include/triton/tools/sys/mkdir.hpp | 74 +- include/triton/tools/thread_pool.h | 121 +- lib/Analysis/AxisInfo.cpp | 138 +- lib/Conversion/PassDetail.h | 10 +- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 265 +- lib/Dialect/Triton/IR/Dialect.cpp | 8 +- lib/Dialect/Triton/IR/Ops.cpp | 58 +- lib/Dialect/Triton/IR/Types.cpp | 4 +- lib/Dialect/Triton/Transforms/Combine.cpp | 61 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 112 +- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 4 +- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 126 +- .../Transforms/TritonGPUConversion.cpp | 109 +- lib/Dialect/TritonGPU/Transforms/Verifier.cpp | 17 +- lib/driver/dispatch.cc | 408 +- lib/driver/error.cc | 410 +- lib/driver/llvm.cc | 248 +- python/src/pybind11/attr.h | 610 +-- python/src/pybind11/buffer_info.h | 165 +- python/src/pybind11/cast.h | 3420 ++++++++------- python/src/pybind11/chrono.h | 241 +- python/src/pybind11/common.h | 3 +- python/src/pybind11/complex.h | 59 +- python/src/pybind11/detail/class.h | 861 ++-- python/src/pybind11/detail/common.h | 1012 +++-- python/src/pybind11/detail/descr.h | 96 +- python/src/pybind11/detail/init.h | 502 ++- python/src/pybind11/detail/internals.h | 445 +- python/src/pybind11/detail/typeid.h | 35 +- python/src/pybind11/eigen.h | 1044 +++-- python/src/pybind11/embed.h | 182 +- python/src/pybind11/eval.h | 149 +- python/src/pybind11/functional.h | 154 +- python/src/pybind11/iostream.h | 168 +- python/src/pybind11/numpy.h | 2616 ++++++------ python/src/pybind11/operators.h | 284 +- python/src/pybind11/options.h | 74 +- python/src/pybind11/pybind11.h | 3714 +++++++++-------- python/src/pybind11/pytypes.h | 2154 +++++----- python/src/pybind11/stl.h | 546 +-- python/src/pybind11/stl_bind.h | 911 ++-- python/src/triton.cc | 2018 +++++---- test/lib/Analysis/TestAxisInfo.cpp | 52 +- 62 files changed, 13673 insertions(+), 11367 deletions(-) mode change 100755 => 100644 include/triton/driver/dispatch.h mode change 100755 => 100644 include/triton/driver/error.h mode change 100755 => 100644 include/triton/tools/sys/getenv.hpp mode change 100755 => 100644 include/triton/tools/sys/mkdir.hpp mode change 100755 => 100644 lib/driver/dispatch.cc mode change 100755 => 100644 lib/driver/error.cc diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 89ac8f403..cabbdba90 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -27,10 +27,16 @@ jobs: pip install isort isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 ) - - name: Check style + - name: Check python style run: | pip install autopep8 autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 ) + + - name: Check cpp style + run: | + sudo apt-get install clang-format + find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i || + (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1) - name: Flake8 run: | diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp index 6742853c9..819efc354 100644 --- a/bin/FileCheck/FileCheck.cpp +++ b/bin/FileCheck/FileCheck.cpp @@ -57,9 +57,8 @@ static cl::opt NoCanonicalizeWhiteSpace( "strict-whitespace", cl::desc("Do not treat all horizontal whitespace as equivalent")); -static cl::opt IgnoreCase( - "ignore-case", - cl::desc("Use case-insensitive matching")); +static cl::opt IgnoreCase("ignore-case", + cl::desc("Use case-insensitive matching")); static cl::list ImplicitCheckNot( "implicit-check-not", @@ -169,12 +168,6 @@ static cl::list DumpInputContexts( typedef cl::list::const_iterator prefix_iterator; - - - - - - static void DumpCommandLine(int argc, char **argv) { errs() << "FileCheck command line: "; for (int I = 0; I < argc; I++) @@ -613,8 +606,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req, ElidedLinesOS.enable_colors(true); auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end(); for (unsigned Line = 1; - InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; - ++Line) { + InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) { const unsigned char *InputFileLine = InputFilePtr; // Compute the previous and next line included by the filter. @@ -691,8 +683,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req, unsigned InputLineWidth = InputFilePtr - InputFileLine; // Print any annotations. - while (AnnotationItr != AnnotationEnd && - AnnotationItr->InputLine == Line) { + while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) { WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true, /*BG=*/false, TheColorMode); // The two spaces below are where the ": " appears on input lines. diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index d5d73e5f6..4942214cc 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -10,11 +10,11 @@ #include "mlir/InitAllPasses.h" #include "mlir/Support/MlirOptMain.h" -namespace mlir{ -namespace test{ +namespace mlir { +namespace test { void registerTestAlignmentPass(); } -} +} // namespace mlir int main(int argc, char **argv) { mlir::registerAllPasses(); @@ -25,13 +25,11 @@ int main(int argc, char **argv) { // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; - registry.insert(); + registry + .insert(); - return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry) - ); + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); } diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 0910f341f..c9be250fc 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -10,7 +10,6 @@ namespace mlir { - //===----------------------------------------------------------------------===// // AxisInfo //===----------------------------------------------------------------------===// @@ -25,26 +24,25 @@ public: public: // Default constructor - AxisInfo(): AxisInfo({}, {}, {}) { } + AxisInfo() : AxisInfo({}, {}, {}) {} // Construct contiguity info with known contiguity AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility, ConstancyT knownConstancy) - : contiguity(knownContiguity), divisibility(knownDivisibility), - constancy(knownConstancy), rank(contiguity.size()) { - assert(knownDivisibility.size() == rank); - assert(knownConstancy.size() == rank); - } - - + : contiguity(knownContiguity), divisibility(knownDivisibility), + constancy(knownConstancy), rank(contiguity.size()) { + assert(knownDivisibility.size() == rank); + assert(knownConstancy.size() == rank); + } + // Accessors - int getContiguity(size_t d) const { return contiguity[d]; } - const ContiguityT& getContiguity() const { return contiguity; } + int getContiguity(size_t d) const { return contiguity[d]; } + const ContiguityT &getContiguity() const { return contiguity; } int getDivisibility(size_t d) const { return divisibility[d]; } - const DivisibilityT& getDivisibility() const { return divisibility; } + const DivisibilityT &getDivisibility() const { return divisibility; } - int getConstancy(size_t d) const { return constancy[d]; } - const ConstancyT& getConstancy() const { return constancy; } + int getConstancy(size_t d) const { return constancy[d]; } + const ConstancyT &getConstancy() const { return constancy; } int getRank() const { return rank; } @@ -56,13 +54,13 @@ public: } /// The pessimistic value state of the contiguity is unknown. - static AxisInfo getPessimisticValueState(MLIRContext *context) - { return AxisInfo(); } + static AxisInfo getPessimisticValueState(MLIRContext *context) { + return AxisInfo(); + } static AxisInfo getPessimisticValueState(Value value); // The gcd of both arguments for each dimension - static AxisInfo join(const AxisInfo &lhs, - const AxisInfo &rhs); + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); private: /// The _contiguity_ information maps the `d`-th @@ -81,7 +79,7 @@ private: /// [19, 23, 27, 31] /// Would have contiguity [2, 1]. ContiguityT contiguity; - + /// The _divisibility_ information maps the `d`-th /// dimension to the largest power-of-two that /// divides the first element of all the values along it @@ -107,39 +105,36 @@ private: /// [16, 16, 16, 16, 20, 20, 20, 20] /// would have constancy [1, 4] ConstancyT constancy; - + // number of dimensions of the lattice int rank; }; - -class AxisInfoAnalysis - : public ForwardDataFlowAnalysis { +class AxisInfoAnalysis : public ForwardDataFlowAnalysis { private: static const int maxPow2Divisor = 65536; - - int highestPowOf2Divisor(int n){ - if(n==0) + + int highestPowOf2Divisor(int n) { + if (n == 0) return maxPow2Divisor; return (n & (~(n - 1))); } - AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo, - const std::function& getContiguity, - const std::function& getDivisibility, - const std::function& getConstancy); + AxisInfo visitBinaryOp( + Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo, + const std::function &getContiguity, + const std::function &getDivisibility, + const std::function &getConstancy); public: using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; - ChangeResult visitOperation(Operation *op, - ArrayRef *> operands) override; - + ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) override; }; - -} - +} // namespace mlir #endif \ No newline at end of file diff --git a/include/triton/Conversion/Passes.h b/include/triton/Conversion/Passes.h index 125551f5c..8cf53bc1c 100644 --- a/include/triton/Conversion/Passes.h +++ b/include/triton/Conversion/Passes.h @@ -3,17 +3,13 @@ #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" -namespace mlir -{ -namespace triton -{ +namespace mlir { +namespace triton { #define GEN_PASS_REGISTRATION #include "triton/Conversion/Passes.h.inc" - } // namespace triton } // namespace mlir - #endif \ No newline at end of file diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h index b21b6a1f1..bdb058249 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h @@ -3,18 +3,17 @@ #include -namespace mlir{ +namespace mlir { class ModuleOp; template class OperationPass; -namespace triton{ +namespace triton { -std::unique_ptr> +std::unique_ptr> createConvertTritonToTritonGPUPass(int numWarps = 4); } } // namespace mlir - #endif \ No newline at end of file diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 80a2aab2e..8590db9c4 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -1,17 +1,16 @@ #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ #define TRITON_DIALECT_TRITON_IR_DIALECT_H_ - +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/SCF/SCF.h" -#include "triton/Dialect/Triton/IR/Traits.h" -#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" #define GET_OP_CLASSES #include "triton/Dialect/Triton/IR/Ops.h.inc" diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index a72e2c99c..fd20236f1 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -19,7 +19,7 @@ public: static LogicalResult verifyTrait(Operation *op) { // The rationale for this number is to prevent users from creating programs // that would have catastrophic register pressure and cause the compiler to - // hang. + // hang. // Since H100 has 256KB registers, we should allow users to create tensors // of size up to 256K elements. It will spill for datatypes wider than 1B, // but we probably should limit number of elements (rather than bytes) to @@ -31,8 +31,8 @@ public: for (int64_t s : tensorType.getShape()) numElements *= s; if (numElements > maxElement) - return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but " - << *op << " has more than that"; + return op->emitError("Maximum allowed number of elements is ") + << maxElement << ", but " << *op << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") << *op << " doesn't follow the rule"; @@ -45,8 +45,8 @@ public: for (int64_t s : tensorType.getShape()) numElements *= s; if (numElements > maxElement) - return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but " - << *op << " has more than that"; + return op->emitError("Maximum allowed number of elements is ") + << maxElement << ", but " << *op << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") << *op << " doesn't follow the rule"; @@ -57,7 +57,7 @@ public: } }; -} -} +} // namespace OpTrait +} // namespace mlir #endif diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index 1064501b1..5dae1a498 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -13,6 +13,6 @@ std::unique_ptr createCombineOpsPass(); #define GEN_PASS_REGISTRATION #include "triton/Dialect/Triton/Transforms/Passes.h.inc" -} +} // namespace mlir #endif diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index dfa5ef864..9e8605ec8 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -15,5 +15,4 @@ #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" - #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index fd9048570..6cb59c327 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -14,6 +14,7 @@ namespace mlir { class TritonGPUTypeConverter : public TypeConverter { public: TritonGPUTypeConverter(MLIRContext *context, int numThreads); + private: MLIRContext *context; int numThreads; @@ -21,8 +22,10 @@ private: class TritonGPUConversionTarget : public ConversionTarget { TritonGPUTypeConverter &typeConverter; + public: - explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); /// update layouts & insert ConvertLayoutOp if necessary LogicalResult refineLayouts(ModuleOp mod, int numThreads); diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h old mode 100755 new mode 100644 index 5503bacaf..85fc2cbc9 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -3,10 +3,10 @@ #ifndef _TRITON_DRIVER_DISPATCH_H_ #define _TRITON_DRIVER_DISPATCH_H_ -#include #include +#include -//CUDA Backend +// CUDA Backend #include "triton/external/CUDA/cuda.h" #include "triton/external/CUDA/nvml.h" @@ -14,47 +14,43 @@ //#define __HIP_PLATFORM_AMD__ #include "triton/external/hip.h" -//Exceptions +// Exceptions #include #include namespace llvm { class PassRegistry; class Module; -} +} // namespace llvm -namespace triton -{ -namespace driver -{ +namespace triton { +namespace driver { class cu_context; -template void check(T){} +template void check(T) {} void check(CUresult err); void check(hipError_t err); -class dispatch -{ +class dispatch { protected: - template - struct return_type; + template struct return_type; - template - struct return_type - { typedef R type; }; + template struct return_type { + typedef R type; + }; typedef bool (*f_init_t)(); - template - static typename return_type::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args) - { + template + static typename return_type::type + f_impl(void *&lib_h, FunPtrT, void *&cache, const char *name, Args... args) { initializer(); - if(cache == nullptr){ + if (cache == nullptr) { cache = dlsym(lib_h, name); - if(cache == 0) - throw std::runtime_error("dlsym unable to load function"); - } + if (cache == 0) + throw std::runtime_error("dlsym unable to load function"); + } FunPtrT fptr; *reinterpret_cast(&fptr) = cache; typename return_type::type res = (*fptr)(args...); @@ -76,63 +72,99 @@ public: // context management static CUresult cuInit(unsigned int Flags); static CUresult cuCtxDestroy_v2(CUcontext ctx); - static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); + static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, + CUdevice dev); static CUresult cuCtxPushCurrent_v2(CUcontext ctx); static CUresult cuCtxPopCurrent_v2(CUcontext *pctx); - static CUresult cuCtxGetDevice(CUdevice* result); - static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags); + static CUresult cuCtxGetDevice(CUdevice *result); + static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int flags); static CUresult cuDriverGetVersion(int *driverVersion); // device management static CUresult cuDeviceGet(CUdevice *device, int ordinal); static CUresult cuDeviceGetName(char *name, int len, CUdevice dev); static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev); - static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); + static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev); static CUresult cuDeviceGetCount(int *count); // link management - static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); - static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); - static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); + static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, + CUjit_option *options, void **optionValues); + static CUresult cuLinkCreate_v2(unsigned int numOptions, + CUjit_option *options, void **optionValues, + CUlinkState *stateOut); + static CUresult cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut); static CUresult cuLinkDestroy(CUlinkState state); // module management - static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); + static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name); static CUresult cuModuleLoad(CUmodule *module, const char *fname); - static CUresult cuModuleLoadData(CUmodule* module, const void* image); + static CUresult cuModuleLoadData(CUmodule *module, const void *image); static CUresult cuModuleUnload(CUmodule hmod); - static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); - static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); + static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues); + static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name); // stream management static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags); static CUresult cuStreamSynchronize(CUstream hStream); - static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx); + static CUresult cuStreamGetCtx(CUstream hStream, CUcontext *pctx); static CUresult cuStreamDestroy_v2(CUstream hStream); - static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); + static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra); // function management - static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc); - static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); + static CUresult cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc); + static CUresult cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value); static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); // memory management static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); - static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr); - static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream); - static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); + static CUresult cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr); + static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, + CUstream stream); + static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount); static CUresult cuMemFree_v2(CUdeviceptr dptr); - static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); - static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); - static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); + static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream); + static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, + const void *srcHost, size_t ByteCount, + CUstream hStream); + static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount); // event management static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags); - static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); + static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd); static CUresult cuEventRecord(CUevent hEvent, CUstream hStream); static CUresult cuEventDestroy_v2(CUevent hEvent); - /* ------------------- * * NVML * ------------------- */ - static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device); - static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); - static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); - static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock); + static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2(const char *pciBusId, + nvmlDevice_t *device); + static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, + nvmlClockType_t type, + unsigned int *clock); + static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, + nvmlClockType_t type, + unsigned int *clock); + static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, + unsigned int mem_clock, + unsigned int sm_clock); /* ------------------- * * HIP @@ -140,177 +172,198 @@ public: // context management static hipError_t hipInit(unsigned int Flags); static hipError_t hipCtxDestroy(hipCtx_t ctx); - static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev); + static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, + hipDevice_t dev); static hipError_t hipCtxPushCurrent(hipCtx_t ctx); static hipError_t hipCtxPopCurrent(hipCtx_t *pctx); - static hipError_t hipCtxGetDevice(hipDevice_t* result); - static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags); + static hipError_t hipCtxGetDevice(hipDevice_t *result); + static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, + unsigned int flags); static hipError_t hipDriverGetVersion(int *driverVersion); // device management static hipError_t hipGetDevice(hipDevice_t *device, int ordinal); static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev); static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev); - static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev); + static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, + hipDevice_t dev); static hipError_t hipGetDeviceCount(int *count); // module management - static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name); + static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes, + hipModule_t hmod, const char *name); static hipError_t hipModuleLoad(hipModule_t *module, const char *fname); - static hipError_t hipModuleLoadData(hipModule_t* module, const void* image); + static hipError_t hipModuleLoadData(hipModule_t *module, const void *image); static hipError_t hipModuleUnload(hipModule_t hmod); - static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues); - static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name); + static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, + unsigned int numOptions, + hipJitOption *options, + void **optionValues); + static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, + const char *name); // stream management static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags); static hipError_t hipStreamSynchronize(hipStream_t hStream); static hipError_t hipStreamDestroy(hipStream_t hStream); - static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra); + static hipError_t + hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, + hipStream_t hStream, void **kernelParams, void **extra); // function management - static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc); - static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value); - static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config); + static hipError_t hipFuncGetAttributes(hipFuncAttributes *attrib, + void *hfunc); + static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, + hipFuncAttribute attrib, int value); + static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, + hipFuncCache_t config); // memory management static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize); - static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr); - static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream); - static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount); + static hipError_t hipPointerGetAttribute(void *data, + CUpointer_attribute attribute, + hipDeviceptr_t ptr); + static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, + size_t N, hipStream_t stream); + static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, + size_t ByteCount); static hipError_t hipFree(hipDeviceptr_t dptr); - static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream); - static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream); - static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount); + static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, + size_t ByteCount, hipStream_t hStream); + static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, + const void *srcHost, size_t ByteCount, + hipStream_t hStream); + static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, + size_t ByteCount); // event management static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags); - static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd); + static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, + hipEvent_t hEnd); static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream); static hipError_t hipEventDestroy(hipEvent_t hEvent); - - private: - // Libraries - static void* cuda_; - static void* nvml_; - static void* hip_; - + static void *cuda_; + static void *nvml_; + static void *hip_; /* ------------------- * * CUDA * ------------------- */ // context management - static void* cuCtxGetCurrent_; - static void* cuCtxSetCurrent_; - static void* cuCtxDestroy_v2_; - static void* cuCtxCreate_v2_; - static void* cuCtxGetDevice_; - static void* cuCtxPushCurrent_v2_; - static void* cuCtxPopCurrent_v2_; - static void* cuCtxEnablePeerAccess_; - static void* cuDriverGetVersion_; - static void* cuInit_; + static void *cuCtxGetCurrent_; + static void *cuCtxSetCurrent_; + static void *cuCtxDestroy_v2_; + static void *cuCtxCreate_v2_; + static void *cuCtxGetDevice_; + static void *cuCtxPushCurrent_v2_; + static void *cuCtxPopCurrent_v2_; + static void *cuCtxEnablePeerAccess_; + static void *cuDriverGetVersion_; + static void *cuInit_; // device management - static void* cuDeviceGet_; - static void* cuDeviceGetName_; - static void* cuDeviceGetPCIBusId_; - static void* cuDeviceGetAttribute_; - static void* cuDeviceGetCount_; + static void *cuDeviceGet_; + static void *cuDeviceGetName_; + static void *cuDeviceGetPCIBusId_; + static void *cuDeviceGetAttribute_; + static void *cuDeviceGetCount_; // link management - static void* cuLinkAddData_v2_; - static void* cuLinkCreate_v2_; - static void* cuLinkDestroy_; - static void* cuLinkComplete_; + static void *cuLinkAddData_v2_; + static void *cuLinkCreate_v2_; + static void *cuLinkDestroy_; + static void *cuLinkComplete_; // module management - static void* cuModuleGetGlobal_v2_; - static void* cuModuleLoad_; - static void* cuModuleUnload_; - static void* cuModuleLoadDataEx_; - static void* cuModuleLoadData_; - static void* cuModuleGetFunction_; + static void *cuModuleGetGlobal_v2_; + static void *cuModuleLoad_; + static void *cuModuleUnload_; + static void *cuModuleLoadDataEx_; + static void *cuModuleLoadData_; + static void *cuModuleGetFunction_; // stream management - static void* cuStreamCreate_; - static void* cuStreamSynchronize_; - static void* cuStreamDestroy_v2_; - static void* cuStreamGetCtx_; - static void* cuLaunchKernel_; + static void *cuStreamCreate_; + static void *cuStreamSynchronize_; + static void *cuStreamDestroy_v2_; + static void *cuStreamGetCtx_; + static void *cuLaunchKernel_; // function management - static void* cuFuncGetAttribute_; - static void* cuFuncSetAttribute_; - static void* cuFuncSetCacheConfig_; + static void *cuFuncGetAttribute_; + static void *cuFuncSetAttribute_; + static void *cuFuncSetCacheConfig_; // memory management - static void* cuMemcpyDtoH_v2_; - static void* cuMemFree_v2_; - static void* cuMemcpyDtoHAsync_v2_; - static void* cuMemcpyHtoDAsync_v2_; - static void* cuMemcpyHtoD_v2_; - static void* cuMemAlloc_v2_; - static void* cuMemsetD8Async_; - static void* cuPointerGetAttribute_; + static void *cuMemcpyDtoH_v2_; + static void *cuMemFree_v2_; + static void *cuMemcpyDtoHAsync_v2_; + static void *cuMemcpyHtoDAsync_v2_; + static void *cuMemcpyHtoD_v2_; + static void *cuMemAlloc_v2_; + static void *cuMemsetD8Async_; + static void *cuPointerGetAttribute_; // event management - static void* cuEventCreate_; - static void* cuEventElapsedTime_; - static void* cuEventRecord_; - static void* cuEventDestroy_v2_; + static void *cuEventCreate_; + static void *cuEventElapsedTime_; + static void *cuEventRecord_; + static void *cuEventDestroy_v2_; /* ------------------- * * NVML * ------------------- */ - static void* nvmlInit_v2_; - static void* nvmlDeviceGetHandleByPciBusId_v2_; - static void* nvmlDeviceGetClockInfo_; - static void* nvmlDeviceGetMaxClockInfo_; - static void* nvmlDeviceSetApplicationsClocks_; + static void *nvmlInit_v2_; + static void *nvmlDeviceGetHandleByPciBusId_v2_; + static void *nvmlDeviceGetClockInfo_; + static void *nvmlDeviceGetMaxClockInfo_; + static void *nvmlDeviceSetApplicationsClocks_; /* ------------------- * * HIP * ------------------- */ // context management - static void* hipInit_; - static void* hipCtxDestroy_; - static void* hipCtxCreate_; - static void* hipCtxPushCurrent_; - static void* hipCtxPopCurrent_; - static void* hipCtxGetDevice_; - static void* hipCtxEnablePeerAccess_; - static void* hipDriverGetVersion_; + static void *hipInit_; + static void *hipCtxDestroy_; + static void *hipCtxCreate_; + static void *hipCtxPushCurrent_; + static void *hipCtxPopCurrent_; + static void *hipCtxGetDevice_; + static void *hipCtxEnablePeerAccess_; + static void *hipDriverGetVersion_; // device management - static void* hipGetDevice_; - static void* hipDeviceGetName_; - static void* hipDeviceGetPCIBusId_; - static void* hipDeviceGetAttribute_; - static void* hipGetDeviceCount_; + static void *hipGetDevice_; + static void *hipDeviceGetName_; + static void *hipDeviceGetPCIBusId_; + static void *hipDeviceGetAttribute_; + static void *hipGetDeviceCount_; // module management - static void* hipModuleGetGlobal_; - static void* hipModuleLoad_; - static void* hipModuleLoadData_; - static void* hipModuleUnload_; - static void* hipModuleLoadDataEx_; - static void* hipModuleGetFunction_; + static void *hipModuleGetGlobal_; + static void *hipModuleLoad_; + static void *hipModuleLoadData_; + static void *hipModuleUnload_; + static void *hipModuleLoadDataEx_; + static void *hipModuleGetFunction_; // stream management - static void* hipStreamCreate_; - static void* hipStreamSynchronize_; - static void* hipStreamDestroy_; - static void* hipModuleLaunchKernel_;; + static void *hipStreamCreate_; + static void *hipStreamSynchronize_; + static void *hipStreamDestroy_; + static void *hipModuleLaunchKernel_; + ; // function management - static void* hipFuncGetAttributes_; - static void* hipFuncSetAttribute_; - static void* hipFuncSetCacheConfig_; + static void *hipFuncGetAttributes_; + static void *hipFuncSetAttribute_; + static void *hipFuncSetCacheConfig_; // memory management - static void* hipMalloc_; - static void* hipPointerGetAttribute_; - static void* hipMemsetD8Async_; - static void* hipMemcpyDtoH_; - static void* hipFree_; - static void* hipMemcpyDtoHAsync_; - static void* hipMemcpyHtoDAsync_; - static void* hipMemcpyHtoD_; + static void *hipMalloc_; + static void *hipPointerGetAttribute_; + static void *hipMemsetD8Async_; + static void *hipMemcpyDtoH_; + static void *hipFree_; + static void *hipMemcpyDtoHAsync_; + static void *hipMemcpyHtoDAsync_; + static void *hipMemcpyHtoD_; // event management - static void* hipEventCreate_; - static void* hipEventElapsedTime_; - static void* hipEventRecord_; - static void* hipEventDestroy_; + static void *hipEventCreate_; + static void *hipEventElapsedTime_; + static void *hipEventRecord_; + static void *hipEventDestroy_; }; -} -} - +} // namespace driver +} // namespace triton #endif diff --git a/include/triton/driver/error.h b/include/triton/driver/error.h old mode 100755 new mode 100644 index c3168c1ad..229e1dee4 --- a/include/triton/driver/error.h +++ b/include/triton/driver/error.h @@ -3,223 +3,252 @@ #ifndef _TRITON_DRIVER_ERROR_H_ #define _TRITON_DRIVER_ERROR_H_ -#include #include "triton/driver/dispatch.h" +#include +namespace triton { -namespace triton -{ +namespace driver { - namespace driver - { +namespace exception { - namespace exception - { +namespace nvrtc { - namespace nvrtc - { +#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \ + class name : public std::exception { \ + public: \ + const char *what() const throw() override { return "NVRTC: Error- " msg; } \ + } -#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \ -class name: public std::exception { public: const char * what() const throw() override { return "NVRTC: Error- " msg; } } - - TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory"); - TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure"); - TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input"); - TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program"); - TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option"); - TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation"); - TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure"); - TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error"); +TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory, "out of memory"); +TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure, + "program creation failure"); +TRITON_CREATE_NVRTC_EXCEPTION(invalid_input, "invalid input"); +TRITON_CREATE_NVRTC_EXCEPTION(invalid_program, "invalid program"); +TRITON_CREATE_NVRTC_EXCEPTION(invalid_option, "invalid option"); +TRITON_CREATE_NVRTC_EXCEPTION(compilation, "compilation"); +TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure, + "builtin operation failure"); +TRITON_CREATE_NVRTC_EXCEPTION(unknown_error, "unknown error"); #undef TRITON_CREATE_NVRTC_EXCEPTION +} // namespace nvrtc + +namespace cuda { +class base : public std::exception {}; + +#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \ + class name : public base { \ + public: \ + const char *what() const throw() override { return "CUDA: Error- " msg; } \ } - - namespace cuda - { - class base: public std::exception{}; - -#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \ -class name: public base { public:const char * what() const throw() override { return "CUDA: Error- " msg; } } - - - TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value"); - TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory"); - TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized"); - TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized"); - TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled"); - TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized"); - TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started"); - TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped"); - TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context"); - TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current"); - TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed"); - TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed"); - TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped"); - TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped"); - TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu"); - TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired"); - TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped"); - TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array"); - TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer"); - TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable"); - TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit"); - TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use"); - TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source"); - TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found"); - TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found"); - TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed"); - TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle"); - TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found"); - TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready"); - TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address"); - TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources"); - TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout"); - TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing"); - TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled"); - TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled"); - TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active"); - TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed"); - TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert"); - TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers"); - TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered"); - TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered"); - TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error"); - TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction"); - TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space"); - TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc"); - TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed"); - TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted"); - TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported"); - TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_value, "invalid value"); +TRITON_CREATE_CUDA_EXCEPTION(out_of_memory, "out of memory"); +TRITON_CREATE_CUDA_EXCEPTION(not_initialized, "not initialized"); +TRITON_CREATE_CUDA_EXCEPTION(deinitialized, "deinitialized"); +TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled, "profiler disabled"); +TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized, + "profiler not initialized"); +TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started, + "profiler already started"); +TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped, + "profiler already stopped"); +TRITON_CREATE_CUDA_EXCEPTION(no_device, "no device"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_device, "invalid device"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_image, "invalid image"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_context, "invalid context"); +TRITON_CREATE_CUDA_EXCEPTION(context_already_current, + "context already current"); +TRITON_CREATE_CUDA_EXCEPTION(map_failed, "map failed"); +TRITON_CREATE_CUDA_EXCEPTION(unmap_failed, "unmap failed"); +TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped, "array is mapped"); +TRITON_CREATE_CUDA_EXCEPTION(already_mapped, "already mapped"); +TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu, "no binary for gpu"); +TRITON_CREATE_CUDA_EXCEPTION(already_acquired, "already acquired"); +TRITON_CREATE_CUDA_EXCEPTION(not_mapped, "not mapped"); +TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array, "not mapped as array"); +TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer"); +TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable"); +TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit, "unsupported limit"); +TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use, "context already in use"); +TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported, + "peer access unsupported"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx, "invalid ptx"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context, + "invalid graphics context"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_source, "invalid source"); +TRITON_CREATE_CUDA_EXCEPTION(file_not_found, "file not found"); +TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found, + "shared object symbol not found"); +TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed, + "shared object init failed"); +TRITON_CREATE_CUDA_EXCEPTION(operating_system, "operating system"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_handle, "invalid handle"); +TRITON_CREATE_CUDA_EXCEPTION(not_found, "not found"); +TRITON_CREATE_CUDA_EXCEPTION(not_ready, "not ready"); +TRITON_CREATE_CUDA_EXCEPTION(illegal_address, "illegal address"); +TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources, + "launch out of resources"); +TRITON_CREATE_CUDA_EXCEPTION(launch_timeout, "launch timeout"); +TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing, + "launch incompatible texturing"); +TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled, + "peer access already enabled"); +TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled, + "peer access not enabled"); +TRITON_CREATE_CUDA_EXCEPTION(primary_context_active, "primary context active"); +TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed, "context is destroyed"); +TRITON_CREATE_CUDA_EXCEPTION(assert_error, "assert"); +TRITON_CREATE_CUDA_EXCEPTION(too_many_peers, "too many peers"); +TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered, + "host memory already registered"); +TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered, + "hot memory not registered"); +TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error, "hardware stack error"); +TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction, "illegal instruction"); +TRITON_CREATE_CUDA_EXCEPTION(misaligned_address, "misaligned address"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space, "invalid address space"); +TRITON_CREATE_CUDA_EXCEPTION(invalid_pc, "invalid pc"); +TRITON_CREATE_CUDA_EXCEPTION(launch_failed, "launch failed"); +TRITON_CREATE_CUDA_EXCEPTION(not_permitted, "not permitted"); +TRITON_CREATE_CUDA_EXCEPTION(not_supported, "not supported"); +TRITON_CREATE_CUDA_EXCEPTION(unknown, "unknown"); #undef TRITON_CREATE_CUDA_EXCEPTION +} // namespace cuda + +namespace cublas { +class base : public std::exception {}; + +#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \ + class name : public base { \ + public: \ + const char *what() const throw() override { \ + return "CUBLAS: Error- " msg; \ + } \ } - namespace cublas - { - class base: public std::exception{}; - -#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \ -class name: public base { public: const char * what() const throw() override { return "CUBLAS: Error- " msg; } } - - TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized"); - TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed"); - TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value"); - TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch"); - TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error"); - TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed"); - TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error"); - TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported"); - TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error"); - TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown"); +TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized, "not initialized"); +TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed, "alloc failed"); +TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value, "invalid value"); +TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch, "arch mismatch"); +TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error, "mapping error"); +TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed, "execution failed"); +TRITON_CREATE_CUBLAS_EXCEPTION(internal_error, "internal error"); +TRITON_CREATE_CUBLAS_EXCEPTION(not_supported, "not supported"); +TRITON_CREATE_CUBLAS_EXCEPTION(license_error, "license error"); +TRITON_CREATE_CUBLAS_EXCEPTION(unknown, "unknown"); #undef TRITON_CREATE_CUBLAS_EXCEPTION +} // namespace cublas + +namespace cudnn { +#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \ + class name : public std::exception { \ + public: \ + const char *what() const throw() override { return "CUDNN: Error- " msg; } \ } - namespace cudnn - { -#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \ -class name: public std::exception { public: const char * what() const throw() override { return "CUDNN: Error- " msg; } } +TRITON_CREATE_CUDNN_EXCEPTION(not_initialized, "not initialized"); +TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed, "allocation failed"); +TRITON_CREATE_CUDNN_EXCEPTION(bad_param, "bad param"); +TRITON_CREATE_CUDNN_EXCEPTION(internal_error, "internal error"); +TRITON_CREATE_CUDNN_EXCEPTION(invalid_value, "invalid value"); +TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch, "arch mismatch"); +TRITON_CREATE_CUDNN_EXCEPTION(mapping_error, "mapping error"); +TRITON_CREATE_CUDNN_EXCEPTION(execution_failed, "execution failed"); +TRITON_CREATE_CUDNN_EXCEPTION(not_supported, "not supported"); +TRITON_CREATE_CUDNN_EXCEPTION(license_error, "license error"); +TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing, + "prerequisite missing"); +TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress, "runtime in progress"); +TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow, "runtime fp overflow"); +} // namespace cudnn - TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized"); - TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed"); - TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param"); - TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error"); - TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value"); - TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch"); - TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error"); - TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed"); - TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported"); - TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error"); - TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing"); - TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress"); - TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow"); +namespace hip { +class base : public std::exception {}; + +#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \ + class name : public base { \ + public: \ + const char *what() const throw() override { return "HIP: Error- " msg; } \ } - - - - namespace hip - { - class base: public std::exception{}; - -#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \ -class name: public base { public:const char * what() const throw() override { return "HIP: Error- " msg; } } - - - TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value"); - TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory"); - TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized"); - TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized"); - TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled"); - TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized"); - TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started"); - TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped"); - TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device"); - TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device"); - TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image"); - TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context"); - TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current"); - TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed"); - TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed"); - TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped"); - TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped"); - TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu"); - TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired"); - TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped"); - TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array"); - TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer"); - TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable"); - TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit"); - TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use"); - TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported"); - TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx"); - TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context"); - TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source"); - TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found"); - TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found"); - TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed"); - TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system"); - TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle"); - TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found"); - TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready"); - TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address"); - TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources"); - TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout"); - TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing"); - TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled"); - TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled"); - TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active"); - TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed"); - TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert"); - TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers"); - TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered"); - TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered"); - TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error"); - TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction"); - TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address"); - TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space"); - TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc"); - TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed"); - TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted"); - TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported"); - TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol"); - TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown"); +TRITON_CREATE_HIP_EXCEPTION(invalid_value, "invalid value"); +TRITON_CREATE_HIP_EXCEPTION(out_of_memory, "out of memory"); +TRITON_CREATE_HIP_EXCEPTION(not_initialized, "not initialized"); +TRITON_CREATE_HIP_EXCEPTION(deinitialized, "deinitialized"); +TRITON_CREATE_HIP_EXCEPTION(profiler_disabled, "profiler disabled"); +TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized, + "profiler not initialized"); +TRITON_CREATE_HIP_EXCEPTION(profiler_already_started, + "profiler already started"); +TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped, + "profiler already stopped"); +TRITON_CREATE_HIP_EXCEPTION(no_device, "no device"); +TRITON_CREATE_HIP_EXCEPTION(invalid_device, "invalid device"); +TRITON_CREATE_HIP_EXCEPTION(invalid_image, "invalid image"); +TRITON_CREATE_HIP_EXCEPTION(invalid_context, "invalid context"); +TRITON_CREATE_HIP_EXCEPTION(context_already_current, "context already current"); +TRITON_CREATE_HIP_EXCEPTION(map_failed, "map failed"); +TRITON_CREATE_HIP_EXCEPTION(unmap_failed, "unmap failed"); +TRITON_CREATE_HIP_EXCEPTION(array_is_mapped, "array is mapped"); +TRITON_CREATE_HIP_EXCEPTION(already_mapped, "already mapped"); +TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu, "no binary for gpu"); +TRITON_CREATE_HIP_EXCEPTION(already_acquired, "already acquired"); +TRITON_CREATE_HIP_EXCEPTION(not_mapped, "not mapped"); +TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array, "not mapped as array"); +TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer"); +TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable"); +TRITON_CREATE_HIP_EXCEPTION(unsupported_limit, "unsupported limit"); +TRITON_CREATE_HIP_EXCEPTION(context_already_in_use, "context already in use"); +TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported, "peer access unsupported"); +TRITON_CREATE_HIP_EXCEPTION(invalid_ptx, "invalid ptx"); +TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context, + "invalid graphics context"); +TRITON_CREATE_HIP_EXCEPTION(invalid_source, "invalid source"); +TRITON_CREATE_HIP_EXCEPTION(file_not_found, "file not found"); +TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found, + "shared object symbol not found"); +TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed, + "shared object init failed"); +TRITON_CREATE_HIP_EXCEPTION(operating_system, "operating system"); +TRITON_CREATE_HIP_EXCEPTION(invalid_handle, "invalid handle"); +TRITON_CREATE_HIP_EXCEPTION(not_found, "not found"); +TRITON_CREATE_HIP_EXCEPTION(not_ready, "not ready"); +TRITON_CREATE_HIP_EXCEPTION(illegal_address, "illegal address"); +TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources, "launch out of resources"); +TRITON_CREATE_HIP_EXCEPTION(launch_timeout, "launch timeout"); +TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing, + "launch incompatible texturing"); +TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled, + "peer access already enabled"); +TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled, "peer access not enabled"); +TRITON_CREATE_HIP_EXCEPTION(primary_context_active, "primary context active"); +TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed, "context is destroyed"); +TRITON_CREATE_HIP_EXCEPTION(assert_error, "assert"); +TRITON_CREATE_HIP_EXCEPTION(too_many_peers, "too many peers"); +TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered, + "host memory already registered"); +TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered, + "hot memory not registered"); +TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error, "hardware stack error"); +TRITON_CREATE_HIP_EXCEPTION(illegal_instruction, "illegal instruction"); +TRITON_CREATE_HIP_EXCEPTION(misaligned_address, "misaligned address"); +TRITON_CREATE_HIP_EXCEPTION(invalid_address_space, "invalid address space"); +TRITON_CREATE_HIP_EXCEPTION(invalid_pc, "invalid pc"); +TRITON_CREATE_HIP_EXCEPTION(launch_failed, "launch failed"); +TRITON_CREATE_HIP_EXCEPTION(not_permitted, "not permitted"); +TRITON_CREATE_HIP_EXCEPTION(not_supported, "not supported"); +TRITON_CREATE_HIP_EXCEPTION(invalid_symbol, "invalid symbol"); +TRITON_CREATE_HIP_EXCEPTION(unknown, "unknown"); #undef TRITON_CREATE_CUDA_EXCEPTION - } +} // namespace hip - } - } -} +} // namespace exception +} // namespace driver +} // namespace triton #endif diff --git a/include/triton/driver/llvm.h b/include/triton/driver/llvm.h index c0c1c0f37..b3ce0d0cc 100644 --- a/include/triton/driver/llvm.h +++ b/include/triton/driver/llvm.h @@ -1,20 +1,21 @@ -#include #include "triton/driver/dispatch.h" +#include -namespace llvm{ +namespace llvm { class Module; } -namespace triton{ -namespace driver{ +namespace triton { +namespace driver { void init_llvm(); -std::string path_to_ptxas(int& version); -std::string llir_to_ptx(llvm::Module* module, int cc, int version); -std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc); -CUmodule ptx_to_cumodule(const std::string& ptx, int cc); -std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc); -hipModule_t amdgpu_to_hipmodule(const std::string& path); +std::string path_to_ptxas(int &version); +std::string llir_to_ptx(llvm::Module *module, int cc, int version); +std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas_path, + int cc); +CUmodule ptx_to_cumodule(const std::string &ptx, int cc); +std::string llir_to_amdgpu(llvm::Module *module, const std::string &proc); +hipModule_t amdgpu_to_hipmodule(const std::string &path); -} -} +} // namespace driver +} // namespace triton diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index c0dbd5061..258e06933 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -3,52 +3,55 @@ #ifndef _TRITON_TOOLS_BENCH_H_ #define _TRITON_TOOLS_BENCH_H_ -#include -#include -#include #include "triton/driver/device.h" #include "triton/driver/stream.h" +#include +#include +#include -namespace triton{ -namespace tools{ +namespace triton { +namespace tools { -class timer{ - typedef std::chrono::high_resolution_clock high_resolution_clock; - typedef std::chrono::nanoseconds nanoseconds; +class timer { + typedef std::chrono::high_resolution_clock high_resolution_clock; + typedef std::chrono::nanoseconds nanoseconds; public: - explicit timer(bool run = false) - { if (run) start(); } + explicit timer(bool run = false) { + if (run) + start(); + } - void start() - { _start = high_resolution_clock::now(); } + void start() { _start = high_resolution_clock::now(); } - nanoseconds get() const - { return std::chrono::duration_cast(high_resolution_clock::now() - _start); } + nanoseconds get() const { + return std::chrono::duration_cast( + high_resolution_clock::now() - _start); + } private: - high_resolution_clock::time_point _start; + high_resolution_clock::time_point _start; }; -inline double bench(std::function const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200) -{ +inline double bench(std::function const &op, driver::stream *stream, + size_t warmup = 10, size_t repeat = 200) { timer tmr; std::vector times; double total_time = 0; - for(size_t i = 0; i < warmup; i++) + for (size_t i = 0; i < warmup; i++) op(); stream->synchronize(); tmr.start(); - for(size_t i = 0; i < repeat; i++){ + for (size_t i = 0; i < repeat; i++) { op(); } stream->synchronize(); return (float)tmr.get().count() / repeat; -// return *std::min_element(times.begin(), times.end()); + // return *std::min_element(times.begin(), times.end()); } -} -} +} // namespace tools +} // namespace triton #endif diff --git a/include/triton/tools/graph.h b/include/triton/tools/graph.h index c2ba8d854..3725eb091 100644 --- a/include/triton/tools/graph.h +++ b/include/triton/tools/graph.h @@ -3,16 +3,15 @@ #ifndef _TRITON_TOOLS_THREAD_GRAPH_H_ #define _TRITON_TOOLS_THREAD_GRAPH_H_ +#include #include #include #include -#include namespace triton { -namespace tools{ +namespace tools { -template -class graph { +template class graph { typedef std::map> edges_t; public: @@ -21,27 +20,27 @@ public: private: void connected_components_impl(node_t x, std::set &nodes, - nmap_t* nmap, cmap_t* cmap, int id) const { - if(nmap) + nmap_t *nmap, cmap_t *cmap, int id) const { + if (nmap) (*nmap)[x] = id; - if(cmap) + if (cmap) (*cmap)[id].push_back(x); - if(nodes.find(x) != nodes.end()) { + if (nodes.find(x) != nodes.end()) { nodes.erase(x); - for(const node_t &y: edges_.at(x)) + for (const node_t &y : edges_.at(x)) connected_components_impl(y, nodes, nmap, cmap, id); } } public: void connected_components(cmap_t *cmap, nmap_t *nmap) const { - if(cmap) + if (cmap) cmap->clear(); - if(nmap) + if (nmap) nmap->clear(); std::set nodes = nodes_; unsigned id = 0; - while(!nodes.empty()){ + while (!nodes.empty()) { connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); } } @@ -63,7 +62,7 @@ private: edges_t edges_; }; -} -} +} // namespace tools +} // namespace triton #endif diff --git a/include/triton/tools/sha1.hpp b/include/triton/tools/sha1.hpp index 630a3fd77..1e71034de 100644 --- a/include/triton/tools/sha1.hpp +++ b/include/triton/tools/sha1.hpp @@ -33,154 +33,140 @@ #ifndef _TRITON_TOOLS_SHA1_HPP_ #define _TRITON_TOOLS_SHA1_HPP_ -namespace sha1 +namespace sha1 { +namespace // local { - namespace // local - { - // Rotate an integer value to left. - inline unsigned int rol(const unsigned int value, - const unsigned int steps) - { - return ((value << steps) | (value >> (32 - steps))); - } +// Rotate an integer value to left. +inline unsigned int rol(const unsigned int value, const unsigned int steps) { + return ((value << steps) | (value >> (32 - steps))); +} - // Sets the first 16 integers in the buffert to zero. - // Used for clearing the W buffert. - inline void clearWBuffert(unsigned int* buffert) - { - for (int pos = 16; --pos >= 0;) - { - buffert[pos] = 0; - } - } +// Sets the first 16 integers in the buffert to zero. +// Used for clearing the W buffert. +inline void clearWBuffert(unsigned int *buffert) { + for (int pos = 16; --pos >= 0;) { + buffert[pos] = 0; + } +} - inline void innerHash(unsigned int* result, unsigned int* w) - { - unsigned int a = result[0]; - unsigned int b = result[1]; - unsigned int c = result[2]; - unsigned int d = result[3]; - unsigned int e = result[4]; +inline void innerHash(unsigned int *result, unsigned int *w) { + unsigned int a = result[0]; + unsigned int b = result[1]; + unsigned int c = result[2]; + unsigned int d = result[3]; + unsigned int e = result[4]; - int round = 0; + int round = 0; - #define sha1macro(func,val) \ - { \ - const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \ - e = d; \ - d = c; \ - c = rol(b, 30); \ - b = a; \ - a = t; \ - } +#define sha1macro(func, val) \ + { \ + const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \ + e = d; \ + d = c; \ + c = rol(b, 30); \ + b = a; \ + a = t; \ + } - while (round < 16) - { - sha1macro((b & c) | (~b & d), 0x5a827999) - ++round; - } - while (round < 20) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro((b & c) | (~b & d), 0x5a827999) - ++round; - } - while (round < 40) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro(b ^ c ^ d, 0x6ed9eba1) - ++round; - } - while (round < 60) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc) - ++round; - } - while (round < 80) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro(b ^ c ^ d, 0xca62c1d6) - ++round; - } + while (round < 16) { + sha1macro((b & c) | (~b & d), 0x5a827999)++ round; + } + while (round < 20) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro((b & c) | (~b & d), 0x5a827999)++ round; + } + while (round < 40) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro(b ^ c ^ d, 0x6ed9eba1)++ round; + } + while (round < 60) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)++ round; + } + while (round < 80) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro(b ^ c ^ d, 0xca62c1d6)++ round; + } - #undef sha1macro +#undef sha1macro - result[0] += a; - result[1] += b; - result[2] += c; - result[3] += d; - result[4] += e; - } - } // namespace + result[0] += a; + result[1] += b; + result[2] += c; + result[3] += d; + result[4] += e; +} +} // namespace - inline void calc(const void* src, const int bytelength, unsigned char* hash) - { - // Init the result array. - unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 }; +inline void calc(const void *src, const int bytelength, unsigned char *hash) { + // Init the result array. + unsigned int result[5] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, + 0xc3d2e1f0}; - // Cast the void src pointer to be the byte array we can work with. - const unsigned char* sarray = (const unsigned char*) src; + // Cast the void src pointer to be the byte array we can work with. + const unsigned char *sarray = (const unsigned char *)src; - // The reusable round buffer - unsigned int w[80]; + // The reusable round buffer + unsigned int w[80]; - // Loop through all complete 64byte blocks. - const int endOfFullBlocks = bytelength - 64; - int endCurrentBlock; - int currentBlock = 0; + // Loop through all complete 64byte blocks. + const int endOfFullBlocks = bytelength - 64; + int endCurrentBlock; + int currentBlock = 0; - while (currentBlock <= endOfFullBlocks) - { - endCurrentBlock = currentBlock + 64; + while (currentBlock <= endOfFullBlocks) { + endCurrentBlock = currentBlock + 64; - // Init the round buffer with the 64 byte block data. - for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) - { - // This line will swap endian on big endian and keep endian on little endian. - w[roundPos++] = (unsigned int) sarray[currentBlock + 3] - | (((unsigned int) sarray[currentBlock + 2]) << 8) - | (((unsigned int) sarray[currentBlock + 1]) << 16) - | (((unsigned int) sarray[currentBlock]) << 24); - } - innerHash(result, w); - } - - // Handle the last and not full 64 byte block if existing. - endCurrentBlock = bytelength - currentBlock; - clearWBuffert(w); - int lastBlockBytes = 0; - for (;lastBlockBytes < endCurrentBlock; ++lastBlockBytes) - { - w[lastBlockBytes >> 2] |= (unsigned int) sarray[lastBlockBytes + currentBlock] << ((3 - (lastBlockBytes & 3)) << 3); - } - w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3); - if (endCurrentBlock >= 56) - { - innerHash(result, w); - clearWBuffert(w); - } - w[15] = bytelength << 3; - innerHash(result, w); - - // Store hash in result pointer, and make sure we get in in the correct order on both endian models. - for (int hashByte = 20; --hashByte >= 0;) - { - hash[hashByte] = (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff; - } + // Init the round buffer with the 64 byte block data. + for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) { + // This line will swap endian on big endian and keep endian on little + // endian. + w[roundPos++] = (unsigned int)sarray[currentBlock + 3] | + (((unsigned int)sarray[currentBlock + 2]) << 8) | + (((unsigned int)sarray[currentBlock + 1]) << 16) | + (((unsigned int)sarray[currentBlock]) << 24); } + innerHash(result, w); + } - inline void toHexString(const unsigned char* hash, char* hexstring) - { - const char hexDigits[] = { "0123456789abcdef" }; + // Handle the last and not full 64 byte block if existing. + endCurrentBlock = bytelength - currentBlock; + clearWBuffert(w); + int lastBlockBytes = 0; + for (; lastBlockBytes < endCurrentBlock; ++lastBlockBytes) { + w[lastBlockBytes >> 2] |= + (unsigned int)sarray[lastBlockBytes + currentBlock] + << ((3 - (lastBlockBytes & 3)) << 3); + } + w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3); + if (endCurrentBlock >= 56) { + innerHash(result, w); + clearWBuffert(w); + } + w[15] = bytelength << 3; + innerHash(result, w); - for (int hashByte = 20; --hashByte >= 0;) - { - hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf]; - hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf]; - } - hexstring[40] = 0; - } + // Store hash in result pointer, and make sure we get in in the correct order + // on both endian models. + for (int hashByte = 20; --hashByte >= 0;) { + hash[hashByte] = + (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff; + } +} + +inline void toHexString(const unsigned char *hash, char *hexstring) { + const char hexDigits[] = {"0123456789abcdef"}; + + for (int hashByte = 20; --hashByte >= 0;) { + hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf]; + hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf]; + } + hexstring[40] = 0; +} } // namespace sha1 #endif diff --git a/include/triton/tools/sys/exec.hpp b/include/triton/tools/sys/exec.hpp index 5b664553e..e96a04314 100644 --- a/include/triton/tools/sys/exec.hpp +++ b/include/triton/tools/sys/exec.hpp @@ -7,11 +7,8 @@ #include #include -namespace triton -{ -namespace tools -{ - +namespace triton { +namespace tools { #ifdef _WIN32 #define popen _popen @@ -19,12 +16,12 @@ namespace tools #endif #ifndef WEXITSTATUS -#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255) +#define WEXITSTATUS(stat_val) ((unsigned)(stat_val)&255) #endif -int exec(const std::string& cmd, std::string& result) { +int exec(const std::string &cmd, std::string &result) { char buffer[128]; - FILE* pipe = popen(cmd.c_str(), "r"); + FILE *pipe = popen(cmd.c_str(), "r"); if (!pipe) return 0; result.clear(); @@ -37,10 +34,9 @@ int exec(const std::string& cmd, std::string& result) { } int status = pclose(pipe); return WEXITSTATUS(status); - } -} -} +} // namespace tools +} // namespace triton #endif diff --git a/include/triton/tools/sys/getenv.hpp b/include/triton/tools/sys/getenv.hpp old mode 100755 new mode 100644 index 755a84a66..1f1c57521 --- a/include/triton/tools/sys/getenv.hpp +++ b/include/triton/tools/sys/getenv.hpp @@ -22,26 +22,23 @@ #ifndef TDL_TOOLS_SYS_GETENV_HPP #define TDL_TOOLS_SYS_GETENV_HPP -#include #include +#include -namespace triton -{ +namespace triton { -namespace tools -{ - - inline std::string getenv(const char * name) - { - const char * cstr = std::getenv(name); - if(!cstr) - return ""; - std::string result(cstr); - return result; - } +namespace tools { +inline std::string getenv(const char *name) { + const char *cstr = std::getenv(name); + if (!cstr) + return ""; + std::string result(cstr); + return result; } -} +} // namespace tools + +} // namespace triton #endif diff --git a/include/triton/tools/sys/mkdir.hpp b/include/triton/tools/sys/mkdir.hpp old mode 100755 new mode 100644 index 5198a0098..10cb0da6a --- a/include/triton/tools/sys/mkdir.hpp +++ b/include/triton/tools/sys/mkdir.hpp @@ -22,55 +22,49 @@ #ifndef TDL_TOOLS_SYS_MKDIR_HPP #define TDL_TOOLS_SYS_MKDIR_HPP -#include -#include #include -#include +#include #include +#include +#include #if defined(_WIN32) - #include +#include #endif -namespace triton -{ +namespace triton { -namespace tools -{ - - inline int mkdir(std::string const & path) - { - #if defined(_WIN32) - return _mkdir(path.c_str()); - #else - return ::mkdir(path.c_str(), 0777); - #endif - } - - inline int mkpath(std::string const & path) - { - int status = 0; - size_t pp = 0; - size_t sp; - while ((sp = path.find('/', pp)) != std::string::npos) - { - if (sp != pp){ - status = mkdir(path.substr(0, sp)); - } - pp = sp + 1; - } - return (status==0 || errno==EEXIST)?0:-1; - } - - inline int mtime(std::string const & path) - { - struct stat st; - if(stat(path.c_str(), &st) != 0) - return 0; - return st.st_mtime; - } +namespace tools { +inline int mkdir(std::string const &path) { +#if defined(_WIN32) + return _mkdir(path.c_str()); +#else + return ::mkdir(path.c_str(), 0777); +#endif } +inline int mkpath(std::string const &path) { + int status = 0; + size_t pp = 0; + size_t sp; + while ((sp = path.find('/', pp)) != std::string::npos) { + if (sp != pp) { + status = mkdir(path.substr(0, sp)); + } + pp = sp + 1; + } + return (status == 0 || errno == EEXIST) ? 0 : -1; } +inline int mtime(std::string const &path) { + struct stat st; + if (stat(path.c_str(), &st) != 0) + return 0; + return st.st_mtime; +} + +} // namespace tools + +} // namespace triton + #endif diff --git a/include/triton/tools/thread_pool.h b/include/triton/tools/thread_pool.h index fbcf2b684..e8a6ca6ca 100644 --- a/include/triton/tools/thread_pool.h +++ b/include/triton/tools/thread_pool.h @@ -3,88 +3,79 @@ #ifndef _TRITON_TOOLS_THREAD_POOL_H_ #define _TRITON_TOOLS_THREAD_POOL_H_ -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include +#include class ThreadPool { public: - ThreadPool(size_t threads) - : stop(false) { - for(size_t i = 0;i < threads;++i) - workers.emplace_back( - [this] { - for(;;){ - std::function task; - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait(lock, - [this]{ return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - task(); - } - } - ); - } + ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + task(); + } + }); + } + template + auto enqueue(F &&f, Args &&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type> + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); { - using return_type = typename std::result_of::type; + std::unique_lock lock(queue_mutex); - auto task = std::make_shared< std::packaged_task >( - std::bind(std::forward(f), std::forward(args)...) - ); + // don't allow enqueueing after stopping the pool + if (stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if(stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); - - tasks.emplace([task](){ (*task)(); }); - } - condition.notify_one(); - return res; + tasks.emplace([task]() { (*task)(); }); } + condition.notify_one(); + return res; + } - - ~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for(std::thread &worker: workers) - worker.join(); + ~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; } - + condition.notify_all(); + for (std::thread &worker : workers) + worker.join(); + } private: - // need to keep track of threads so we can join them - std::vector< std::thread > workers; - // the task queue - std::queue< std::function > tasks; + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue> tasks; - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; }; - #endif diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 6222e5261..ef926c190 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -8,24 +8,23 @@ namespace mlir { - //===----------------------------------------------------------------------===// // AxisInfo //===----------------------------------------------------------------------===// - // Function for extended Euclidean Algorithm -static int gcd_impl(int a, int b, int *x, int *y){ +// Function for extended Euclidean Algorithm +static int gcd_impl(int a, int b, int *x, int *y) { // Base Case if (a == 0) { - *x = 0; - *y = 1; - return b; + *x = 0; + *y = 1; + return b; } int x1, y1; // To store results of recursive call - int gcd = gcd_impl(b%a, a, &x1, &y1); + int gcd = gcd_impl(b % a, a, &x1, &y1); // Update x and y using results of // recursive call - *x = y1 - (b/a) * x1; + *x = y1 - (b / a) * x1; *y = x1; return gcd; } @@ -35,17 +34,17 @@ static int gcd(int a, int b) { return gcd_impl(a, b, &x, &y); } - AxisInfo AxisInfo::getPessimisticValueState(Value value) { size_t rank = 1; - if(TensorType ty = value.getType().dyn_cast()) + if (TensorType ty = value.getType().dyn_cast()) rank = ty.getRank(); int divHint = 1; - if(BlockArgument blockArg = value.dyn_cast()){ - Operation* op = blockArg.getOwner()->getParentOp(); - if(FuncOp fun = dyn_cast(op)){ - Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); - if(attr) + if (BlockArgument blockArg = value.dyn_cast()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (FuncOp fun = dyn_cast(op)) { + Attribute attr = + fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); + if (attr) divHint = attr.cast().getValue().getZExtValue(); } } @@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { return AxisInfo(contiguity, divisibility, constancy); } - // The gcd of both arguments for each dimension -AxisInfo AxisInfo::join(const AxisInfo &lhs, - const AxisInfo &rhs) { +AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { ContiguityT retContiguity; DivisibilityT retDivisibility; ConstancyT retConstancy; - for(size_t d = 0; d < lhs.getRank(); d++){ + for (size_t d = 0; d < lhs.getRank(); d++) { retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); - retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + retDivisibility.push_back( + gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); } return AxisInfo(retContiguity, retDivisibility, retConstancy); } - //===----------------------------------------------------------------------===// // AxisInfoAnalysis //===----------------------------------------------------------------------===// -AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo, - const std::function& getContiguity, - const std::function& getDivisibility, - const std::function& getConstancy) { - int rank = lhsInfo.getRank(); - AxisInfo::ContiguityT newContiguity; - AxisInfo::DivisibilityT newDivisibility; - AxisInfo::ConstancyT newConstancy; - for(size_t d = 0; d < rank; d++){ - newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); - newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); - newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d)); - } - return AxisInfo(newContiguity, newDivisibility, newConstancy); +AxisInfo AxisInfoAnalysis::visitBinaryOp( + Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo, + const std::function &getContiguity, + const std::function &getDivisibility, + const std::function &getConstancy) { + int rank = lhsInfo.getRank(); + AxisInfo::ContiguityT newContiguity; + AxisInfo::DivisibilityT newDivisibility; + AxisInfo::ConstancyT newConstancy; + for (size_t d = 0; d < rank; d++) { + newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); + newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); + newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d)); + } + return AxisInfo(newContiguity, newDivisibility, newConstancy); } -ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, - ArrayRef *> operands) { +ChangeResult AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands) { AxisInfo curr; // This preserves the input axes (e.g., cast): if (llvm::isa(op)) curr = operands[0]->getValue(); // Constant ranges - if (triton::MakeRangeOp make_range = llvm::dyn_cast(op)){ + if (triton::MakeRangeOp make_range = + llvm::dyn_cast(op)) { int start = make_range.start(); int end = make_range.end(); AxisInfo::ContiguityT contiguity = {end - start}; @@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, curr = AxisInfo(contiguity, divisibility, constancy); } // Constant - if (arith::ConstantOp constant = llvm::dyn_cast(op)){ + if (arith::ConstantOp constant = llvm::dyn_cast(op)) { auto intAttr = constant.getValue().dyn_cast(); - if(intAttr){ + if (intAttr) { size_t val = intAttr.getValue().getZExtValue(); curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1}); } // TODO: generalize to dense attr auto splatAttr = constant.getValue().dyn_cast(); - if(splatAttr && splatAttr.getElementType().isInteger(32)){ + if (splatAttr && splatAttr.getElementType().isInteger(32)) { auto value = splatAttr.getSplatValue(); TensorType ty = splatAttr.getType().cast(); - curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1), - AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)), - AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end())); - + curr = AxisInfo( + AxisInfo::ContiguityT(ty.getRank(), 1), + AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)), + AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end())); } } // Addition - if (llvm::isa(op)){ - auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){ + if (llvm::isa(op)) { + auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) { return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)), gcd(lhs.getConstancy(d), rhs.getContiguity(d))); }; - auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){ + auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) { return gcd(lhs.getConstancy(d), rhs.getConstancy(d)); }; - auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){ + auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) { return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)); }; curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), - newContiguity, newDivisibility, newConstancy); + newContiguity, newDivisibility, newConstancy); } // Multiplication - if (llvm::isa(op)){ - auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){ - return 1; - }; - auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){ + if (llvm::isa(op)) { + auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; + auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) { return gcd(lhs.getConstancy(d), rhs.getConstancy(d)); }; - auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){ - return lhs.getDivisibility(d)*rhs.getDivisibility(d); + auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { + return lhs.getDivisibility(d) * rhs.getDivisibility(d); }; curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), - newContiguity, newDivisibility, newConstancy); + newContiguity, newDivisibility, newConstancy); } // Splat - if (llvm::isa(op)){ + if (llvm::isa(op)) { Type _retTy = *op->result_type_begin(); TensorType retTy = _retTy.cast(); AxisInfo opInfo = operands[0]->getValue(); AxisInfo::ContiguityT contiguity; AxisInfo::DivisibilityT divisibility; AxisInfo::ConstancyT constancy; - for(size_t d = 0; d < retTy.getRank(); d++){ + for (size_t d = 0; d < retTy.getRank(); d++) { contiguity.push_back(1); divisibility.push_back(opInfo.getDivisibility(0)); constancy.push_back(retTy.getShape()[d]); @@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, } // Reshape // TODO: Replace by `unsqueeze` - if (llvm::isa(op)){ + if (llvm::isa(op)) { Type _retTy = *op->result_type_begin(); Type _opTy = *op->operand_type_begin(); TensorType retTy = _retTy.cast(); @@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, AxisInfo::ConstancyT constancy; bool is_skewed = false; size_t current = 0; - for(size_t d = 0; d < retTy.getRank(); d++){ - if(retShape[d] == 1){ + for (size_t d = 0; d < retTy.getRank(); d++) { + if (retShape[d] == 1) { contiguity.push_back(1); divisibility.push_back(1); constancy.push_back(1); - } - else if(!is_skewed - && retShape[d] == opShape[current]){ + } else if (!is_skewed && retShape[d] == opShape[current]) { contiguity.push_back(opInfo.getContiguity()[current]); divisibility.push_back(opInfo.getDivisibility()[current]); constancy.push_back(opInfo.getConstancy()[current]); current++; - } - else { + } else { is_skewed = true; contiguity.push_back(1); divisibility.push_back(1); @@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, curr = AxisInfo(contiguity, divisibility, constancy); } // Broadcast - if (llvm::isa(op)){ + if (llvm::isa(op)) { Type _retTy = *op->result_type_begin(); Type _opTy = *op->operand_type_begin(); TensorType retTy = _retTy.cast(); @@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, AxisInfo::ContiguityT contiguity; AxisInfo::DivisibilityT divisibility; AxisInfo::ConstancyT constancy; - for(size_t d = 0; d < retTy.getRank(); d++){ + for (size_t d = 0; d < retTy.getRank(); d++) { contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); } curr = AxisInfo(contiguity, divisibility, constancy); } - if(curr.getRank() == 0){ + if (curr.getRank() == 0) { return markAllPessimisticFixpoint(op->getResults()); } // join all latice elements @@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op, return result; } -} \ No newline at end of file +} // namespace mlir \ No newline at end of file diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index e772f41b6..e60f15a84 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -2,14 +2,16 @@ #define TRITON_CONVERSION_PASSDETAIL_H #include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" -namespace mlir{ -namespace triton{ +namespace mlir { +namespace triton { #define GEN_PASS_CLASSES #include "triton/Conversion/Passes.h.inc" -} -} +} // namespace triton +} // namespace mlir #endif diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 22663b504..c927e766d 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -1,42 +1,42 @@ +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" +#include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "../PassDetail.h" using namespace mlir; using namespace mlir::triton; namespace { -template -class ArithGenericPattern : public OpConversionPattern { +template class ArithGenericPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - Op res = rewriter.replaceOpWithNewOp( - op, retType, adaptor.getOperands() - ); + Op res = + rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; -template +template class ArithCmpPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - DstOp res = rewriter.replaceOpWithNewOp( - op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs() - ); + DstOp res = + rewriter.replaceOpWithNewOp(op, retType, adaptor.getPredicate(), + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; @@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); auto value = adaptor.getValue().dyn_cast(); assert(value); rewriter.replaceOpWithNewOp( - op, retType, value.reshape(retType) // This is a hack. We just want to add encoding + op, retType, + value.reshape(retType) // This is a hack. We just want to add encoding ); return success(); } }; -class ConvertArithmeticOp: public ConversionPattern { +class ConvertArithmeticOp : public ConversionPattern { public: - ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, - context) {} + ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, + context) {} - LogicalResult matchAndRewrite(Operation* op, ArrayRef operands, - ConversionPatternRewriter& rewriter) const override { - Dialect* dialect = op->getDialect(); - if(dialect->getTypeID() != mlir::TypeID::get()) - return failure(); - return success(); - } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Dialect *dialect = op->getDialect(); + if (dialect->getTypeID() != mlir::TypeID::get()) + return failure(); + return success(); + } }; void populateArithmeticPatternsAndLegality( - TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns, - TritonGPUConversionTarget &target){ + TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { // -------------- // Add legality and rewrite pattern rules for operations // from the Arithmetic dialect. The basic premise is that @@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality( // ); // Rewrite rule // patterns.add(typeConverter, context); - patterns.add, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, // NegFOp - // Floating point - ArithGenericPattern, - ArithGenericPattern, - // MaxMin - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - // Floating point - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - // Cmp - ArithCmpPattern, - ArithCmpPattern, - // Cast Ops - ArithGenericPattern, - ArithGenericPattern - >(typeConverter, context); + patterns.add< + ArithConstantPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, // NegFOp + // Floating point + ArithGenericPattern, ArithGenericPattern, + // MaxMin + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, ArithGenericPattern, + // Floating point + ArithGenericPattern, ArithGenericPattern, + ArithGenericPattern, + // Cmp + ArithCmpPattern, + ArithCmpPattern, + // Cast Ops + ArithGenericPattern, + ArithGenericPattern>(typeConverter, context); } // // Triton patterns // // TODO: Do we need to put them in anonymous namespace? -struct TritonMakeRangePattern : public OpConversionPattern { +struct TritonMakeRangePattern + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op, retType, adaptor.start(), adaptor.end() - ); + op, retType, adaptor.start(), adaptor.end()); return success(); } }; @@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern struct TritonDotPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); // a & b must be of smem layout auto aType = adaptor.a().getType().cast(); @@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern { Value b = adaptor.b(); SmallVector order{1, 0}; if (!aEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order); - auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); + Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get( + getContext(), 1, 1, 1, order); + auto dstType = RankedTensorType::get(aType.getShape(), + aType.getElementType(), encoding); a = rewriter.create(a.getLoc(), dstType, a); } if (!bEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order); - auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); + Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get( + getContext(), 1, 1, 1, order); + auto dstType = RankedTensorType::get(bType.getShape(), + bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); } auto newDot = rewriter.replaceOpWithNewOp( - op, retType, a, b, adaptor.c(), adaptor.allowTF32() - ); + op, retType, a, b, adaptor.c(), adaptor.allowTF32()); return success(); } }; @@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern { struct TritonLoadPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op, retType, - adaptor.ptr(), adaptor.mask(), adaptor.other(), - adaptor.cache(), adaptor.evict(), adaptor.isVolatile() - ); + op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(), + adaptor.cache(), adaptor.evict(), adaptor.isVolatile()); return success(); } }; @@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern { struct TritonStorePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto newOp = rewriter.replaceOpWithNewOp( - op, adaptor.ptr(), adaptor.value(), adaptor.mask() - ); + op, adaptor.ptr(), adaptor.value(), adaptor.mask()); return success(); } }; @@ -212,12 +209,11 @@ template struct TritonGenericPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, retType, adaptor.getOperands() - ); + rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; @@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern { struct TritonReducePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); auto newOp = rewriter.replaceOpWithNewOp( - op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis() - ); + op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()); return success(); } }; -void populateTritonPatterns( - TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns -) { +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add, TritonGenericPattern, TritonGenericPattern, - TritonGenericPattern, - TritonReducePattern, - TritonMakeRangePattern, - TritonDotPattern, - TritonLoadPattern, - TritonStorePattern - >(typeConverter, context); + TritonGenericPattern, TritonReducePattern, + TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, + TritonStorePattern>(typeConverter, context); } // @@ -259,17 +250,19 @@ void populateTritonPatterns( struct SCFForPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // Ref: ConvertForOpTypes - LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), newOp.getLoopBody().end()); // Now, update all the types. // Convert the types of block arguments within the given region. This - // replaces each block with a new block containing the updated signature. The - // entry block may have a special conversion if `entryConversion` is + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), @@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern { struct SCFYieldPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); // rewriter.create(op.getLoc(), adaptor.getOperands()); // op.erase(); - rewriter.replaceOpWithNewOp( - op, adaptor.getOperands() - ); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; -void populateSCFPatterns( - TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns -) { +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(typeConverter, context); + patterns.add(typeConverter, context); } - -class ConvertTritonToTritonGPU : - public ConvertTritonToTritonGPUBase { +class ConvertTritonToTritonGPU + : public ConvertTritonToTritonGPUBase { public: - ConvertTritonToTritonGPU(int numWarps) { - this->numWarps = numWarps; - } + ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; } void runOnOperation() override { MLIRContext *context = &getContext(); @@ -339,21 +326,21 @@ public: // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); - // TODO: can we use + // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); - if(failed(applyPartialConversion(mod, target, std::move(patterns)))) - return signalPassFailure(); + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); // update layouts // broadcast src => multicast, dst => broadcasted - if(failed(target.refineLayouts(mod, numWarps))) + if (failed(target.refineLayouts(mod, numWarps))) return signalPassFailure(); } }; -} +} // namespace std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) { diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index ff7ce0436..4b286e5b8 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -7,7 +7,6 @@ #include "mlir/IR/DialectImplementation.h" - #include "triton/Dialect/Triton/IR/Dialect.cpp.inc" using namespace mlir; @@ -19,12 +18,13 @@ void TritonDialect::initialize() { addOperations< #define GET_OP_LIST #include "triton/Dialect/Triton/IR/Ops.cpp.inc" - >(); + >(); // We can also add interface here. } -Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { return builder.create(loc, type, value); } \ No newline at end of file diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index fd911b7a3..3d9204183 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -13,14 +13,16 @@ namespace triton { static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); return Type(); } static Type getI32SameShape(Type type) { auto i32Type = IntegerType::get(type.getContext(), 32); if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding()); + return RankedTensorType::get(tensorType.getShape(), i32Type, + tensorType.getEncoding()); return Type(); } @@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) { return Type(); } -} -} +} // namespace triton +} // namespace mlir #define GET_OP_CLASSES #include "triton/Dialect/Triton/IR/Ops.cpp.inc" @@ -48,50 +50,48 @@ namespace triton { //-- StoreOp -- // Default mask -void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) { +void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, + ::mlir::Value ptr, ::mlir::Value value) { TensorType ptrType = ptr.getType().dyn_cast(); auto shape = ptrType.getShape(); ::mlir::Value mask = builder.create( - ptr.getLoc(), - RankedTensorType::get(shape, builder.getI1Type()), - DenseIntElementsAttr::get( - RankedTensorType::get(shape, builder.getI1Type()), true - ) - ); + ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()), + DenseIntElementsAttr::get( + RankedTensorType::get(shape, builder.getI1Type()), true)); state.addOperands(ptr); state.addOperands(value); state.addOperands(mask); } //-- LoadOp -- -void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, - ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) { +void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, + ::mlir::Value ptr, ::mlir::triton::CacheModifier cache, + ::mlir::triton::EvictionPolicy evict, bool isVolatile) { TensorType ptrType = ptr.getType().dyn_cast(); - Type elementType = ptrType.getElementType().dyn_cast().getPointeeType(); + Type elementType = + ptrType.getElementType().dyn_cast().getPointeeType(); auto shape = ptrType.getShape(); // mask ::mlir::Value mask = builder.create( - ptr.getLoc(), - RankedTensorType::get(shape, builder.getI1Type()), - DenseIntElementsAttr::get( - RankedTensorType::get(shape, builder.getI1Type()), true - ) - ); + ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()), + DenseIntElementsAttr::get( + RankedTensorType::get(shape, builder.getI1Type()), true)); // other Type resultType = RankedTensorType::get(shape, elementType); ::mlir::Value other = builder.create( - ptr.getLoc(), - resultType, - DenseElementsAttr::get( - resultType, builder.getZeroAttr(elementType) - ) - ); + ptr.getLoc(), resultType, + DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType))); state.addOperands(ptr); state.addOperands(mask); state.addOperands(other); - state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache)); - state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict)); - state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile)); + state.addAttribute( + cacheAttrName(state.name), + ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache)); + state.addAttribute( + evictAttrName(state.name), + ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict)); + state.addAttribute(isVolatileAttrName(state.name), + builder.getBoolAttr(isVolatile)); state.addTypes({resultType}); } diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 66e8c7b05..5884a2ec4 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -1,6 +1,6 @@ -#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` using namespace mlir; @@ -16,7 +16,7 @@ void TritonDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST #include "triton/Dialect/Triton/IR/Types.cpp.inc" - >(); + >(); } Type PointerType::parse(AsmParser &parser) { diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 2fc073c05..ca5841aad 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -17,21 +17,23 @@ namespace { class CombineDotOp : public mlir::RewritePattern { public: CombineDotOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} - mlir::LogicalResult matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, + context) {} + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { if (llvm::isa(op)) { if (isCandidate(op->getOperand(0)).succeeded()) { auto dotOp = op->getOperand(0).getDefiningOp(); rewriter.replaceOpWithNewOp( - op, dotOp->getResultTypes().front(), dotOp.a(), - dotOp.b(), op->getOperand(1), dotOp.allowTF32()); + op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(), + op->getOperand(1), dotOp.allowTF32()); return mlir::success(); } else if (isCandidate(op->getOperand(1)).succeeded()) { auto dotOp = op->getOperand(1).getDefiningOp(); rewriter.replaceOpWithNewOp( - op, dotOp->getResultTypes().front(), dotOp.a(), - dotOp.b(), op->getOperand(0), dotOp.allowTF32()); + op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(), + op->getOperand(0), dotOp.allowTF32()); return mlir::success(); } } @@ -54,7 +56,7 @@ private: return true; // broadcast(constant_0) if (auto bc = val.getDefiningOp()) { - if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || + if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat())) return true; } @@ -68,18 +70,19 @@ private: class CombineGEPOp : public mlir::RewritePattern { public: CombineGEPOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, + context) {} - mlir::LogicalResult matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { if (llvm::isa(op)) { if (auto gep2 = op->getOperand(0).getDefiningOp()) { auto loc = op->getLoc(); mlir::Value newIdx = rewriter.create( - loc, op->getOperand(1), gep2->getOperand(1)); + loc, op->getOperand(1), gep2->getOperand(1)); rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), gep2->getOperand(0), newIdx - ); + op, op->getResultTypes().front(), gep2->getOperand(0), newIdx); return mlir::success(); } } @@ -92,20 +95,21 @@ public: class CombineSelectMaskedLoadOp : public mlir::RewritePattern { public: CombineSelectMaskedLoadOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, + context) {} - mlir::LogicalResult matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { if (llvm::isa(op)) { if (auto load = op->getOperand(1).getDefiningOp()) { mlir::Value cond = op->getOperand(0); if (auto bc = load.mask().getDefiningOp()) { if (bc.src().getDefiningOp() == cond.getDefiningOp()) { rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), - load.ptr(), load.mask(), op->getOperand(2), - load.cache(), load.evict(), load.isVolatile() - ); + op, op->getResultTypes().front(), load.ptr(), load.mask(), + op->getOperand(2), load.cache(), load.evict(), + load.isVolatile()); return mlir::success(); } } @@ -120,11 +124,11 @@ public: class CombineBroadcastConstantOp : public mlir::RewritePattern { public: CombineBroadcastConstantOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, - context) {} - + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, + context) {} + LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + PatternRewriter &rewriter) const override { if (auto broadcast = llvm::dyn_cast(op)) { if (auto cst = broadcast.src().getDefiningOp()) { Attribute value = cst.getValue(); @@ -132,15 +136,14 @@ public: if (auto denseValue = value.dyn_cast()) { if (!denseValue.isSplat()) return failure(); - value = DenseElementsAttr::get(resType, denseValue.getSplatValue()); + value = DenseElementsAttr::get(resType, + denseValue.getSplatValue()); } else { if (!value.isa()) return failure(); value = DenseElementsAttr::get(resType, value); } - rewriter.replaceOpWithNewOp( - op, value, resType - ); + rewriter.replaceOpWithNewOp(op, value, resType); return success(); } } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d66a08892..127f8366e 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -11,19 +11,18 @@ using namespace mlir::triton::gpu; // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, - /*SmallVector*/auto &res, - StringRef desc) { + /*SmallVector*/ auto &res, + StringRef desc) { auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) { - parser.emitError(parser.getNameLoc(), "expected an array for ") - << desc; + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; return failure(); } for (Attribute i : arrayAttr) { auto intAttr = i.dyn_cast(); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integer value in ") - << desc; + << desc; return failure(); } res.push_back(intAttr.getUInt()); @@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) { return {}; if (parser.parseGreater().failed()) return {}; - + SmallVector threadTileSize; SmallVector warpTileSize; SmallVector blockTileSize; @@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) { for (const NamedAttribute &attr : dict) { if (attr.getName() == "threadTileSize") { - if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed()) + if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size") + .failed()) return {}; } else if (attr.getName() == "warpTileSize") { - if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed()) + if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size") + .failed()) return {}; } else if (attr.getName() == "blockTileSize") { - if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed()) + if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size") + .failed()) return {}; } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; } else if (attr.getName() == "broadcastAxis") { - if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed()) + if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis") + .failed()) return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") @@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) { } } - return parser.getChecked(parser.getContext(), - threadTileSize, - warpTileSize, - blockTileSize, - order, - broadcastAxis); + return parser.getChecked( + parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order, + broadcastAxis); } static void printBlocked(AsmPrinter &printer, auto *attr) { @@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) { << "}>"; } -Attribute -TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { parseBlocked(parser, type); } @@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printBlocked(printer, this); } -Attribute -TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, + Type type) { parseBlocked(parser, type); } @@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) { for (const NamedAttribute &attr : dict) { if (attr.getName() == "fragmentPerWarp") { - if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed()) + if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp") + .failed()) return {}; } else if (attr.getName() == "shapePerWarp") { - if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed()) + if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp") + .failed()) return {}; } else if (attr.getName() == "warpPerTile") { if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed()) return {}; } else if (attr.getName() == "shapePerTile") { - if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed()) + if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile") + .failed()) return {}; } else if (attr.getName() == "repetitions") { if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed()) return {}; } else if (attr.getName() == "contigPerThread") { - if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed()) + if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread") + .failed()) return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); + << attr.getName().strref(); return {}; } } - return parser.getChecked(parser.getContext(), - fragmentPerWarp, - shapePerWarp, - warpPerTile, - shapePerTile, - repetitions, - contigPerThread, - broadcastAxis); + return parser.getChecked( + parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile, + shapePerTile, repetitions, contigPerThread, broadcastAxis); } static void printMma(AsmPrinter &printer, auto *attr) { @@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) { << "}>"; } -Attribute -TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { return parseMma(parser, type); } @@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { printMma(printer, this); } -Attribute -TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, + Type type) { return parseMma(parser, type); } @@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const { printMma(printer, this); } -Attribute -TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary @@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { unsigned maxPhase = 0; SmallVector order; - auto parseUInt = [&parser](const NamedAttribute &attr, - unsigned &value, + auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value, StringRef desc) -> LogicalResult { auto intAttr = attr.getValue().dyn_cast(); if (!intAttr) { @@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); + << attr.getName().strref(); return {}; } } - return parser.getChecked(parser.getContext(), - vec, - perPhase, - maxPhase, - order); + return parser.getChecked( + parser.getContext(), vec, perPhase, maxPhase, order); } void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" - << "vec = " << getVec() - << ", perPhase = " << getPerPhase() - << ", maxPhase = " << getMaxPhase() - << ", order = [" << getOrder() << "]" + << "vec = " << getVec() << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder() + << "]" << "}>"; } class TritonGPUOpAsmInterface : public OpAsmDialectInterface { - public: +public: using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { @@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { OpAsmDialectInterface::getAlias(attr, os); } - private: +private: static void printMma(const auto &attr, raw_ostream &os) { TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os); TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os); @@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() { addOperations< #define GET_OP_LIST #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" - >(); + >(); addInterfaces(); } @@ -349,7 +340,8 @@ namespace triton { static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); return Type(); } @@ -368,8 +360,8 @@ static Type getPointeeType(Type type) { return Type(); } -} -} +} // namespace triton +} // namespace mlir static LogicalResult verify(CopyAsyncOp op) { Type resType = op.getResult().getType(); @@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) { #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" - // verify TritonGPU ops -LogicalResult -TritonGPUDialect::verifyOperationAttribute(Operation *op, - NamedAttribute attr) { +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { // TODO: fill this. return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 0052a3975..92b9127a3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -27,8 +27,8 @@ namespace { #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -class TritonGPUCombineOpsPass - : public TritonGPUCombineOpsBase { +class TritonGPUCombineOpsPass + : public TritonGPUCombineOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index b68276678..13e807921 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -6,12 +6,11 @@ //===----------------------------------------------------------------------===// // // This file implements loop software pipelining -// The implementation here is inspired by the pipeline pass in Triton (-v2.0) +// The implementation here is inspired by the pipeline pass in Triton (-v2.0) // and SCF's LoopPipelining. // //===----------------------------------------------------------------------===// - using namespace mlir; #define GEN_PASS_CLASSES @@ -41,14 +40,15 @@ class LoopPipeliner { /// Block arguments that loads depend on DenseSet depArgs; /// Operations (inside the loop body) that loads depend on - DenseSet depOps; + DenseSet depOps; /// collect values that v depends on and are defined inside the loop void collectDeps(Value v, int stages, DenseSet &deps); void setValueMapping(Value origin, Value newValue, int stage); + public: - LoopPipeliner(scf::ForOp forOp, int numStages) + LoopPipeliner(scf::ForOp forOp, int numStages) : forOp(forOp), numStages(numStages) { // cache yieldOp yieldOp = cast(forOp.getBody()->getTerminator()); @@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { if (auto arg = v.dyn_cast()) { deps.insert(v); // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1 - collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps); + collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps); } else { // value // v might be in deps, but we still need to visit v. // This is because v might depends on value in previous iterations @@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() { } // for (triton::LoadOp loadOp : allLoads) { - // llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n"; - // for (Value dep : loadDeps[loadOp]) + // llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << + // " values\n"; for (Value dep : loadDeps[loadOp]) // llvm::errs() << dep << "\n"; // llvm::errs() << "\n"; // } @@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() { if (isCandiate && loadOp.getResult().hasOneUse()) { isCandiate = false; Operation *use = *loadOp.getResult().getUsers().begin(); - if (auto convertLayout = llvm::dyn_cast(use)) { - if (auto tensorType = convertLayout.getResult().getType().dyn_cast()) { - if (tensorType.getEncoding().isa()) { + if (auto convertLayout = + llvm::dyn_cast(use)) { + if (auto tensorType = convertLayout.getResult() + .getType() + .dyn_cast()) { + if (tensorType.getEncoding() + .isa()) { isCandiate = true; loadsMapping[loadOp] = convertLayout; } @@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() { loads.insert(loadOp); } - // we have some loads to pipeline if (!loads.empty()) { // update depArgs & depOps @@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() { // special handling for loop condition as there is no condition in ForOp Value loopCond = builder.create( - iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); + iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); // rematerialize peeled values - SmallVector orderedDeps; + SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); @@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() { // TODO: check if the hardware supports copyasync if (auto loadOp = llvm::dyn_cast(op)) { newOp = builder.create( - op->getLoc(), loadsMapping[loadOp].getType(), - loadOp.ptr(), loadOp.mask(), loadOp.other(), - loadOp.cache(), loadOp.evict(), loadOp.isVolatile() - ); + op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(), + loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(), + loadOp.isVolatile()); } else llvm_unreachable("This should be LoadOp"); } else @@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() { // assert(I1 or TensorOf<[I1]>); OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(newOp); - Value splatCond = builder.create(mask.getLoc(), - mask.getType(), - loopCond); - Value newMask = builder.create(mask.getLoc(), - mask, - splatCond); + Value splatCond = builder.create( + mask.getLoc(), mask.getType(), loopCond); + Value newMask = + builder.create(mask.getLoc(), mask, splatCond); newOp->setOperand(1, newMask); } @@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() { // update mapping for loop-carried values (args) for (OpOperand &operand : yieldOp->getOpOperands()) { if (operand.get() == op->getResult(dstIdx)) - setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], - newOp->getResult(dstIdx), stage + 1); + setValueMapping( + forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(dstIdx), stage + 1); } } } @@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() { size_t depArgsBeginIdx = newLoopArgs.size(); for (BlockArgument depArg : depArgs) { depArgsIdx[depArg] = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[depArg][numStages-1]); + newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); } size_t nextIVIdx = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]); + newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); for (size_t i = 0; i < newLoopArgs.size(); ++i) assert(newLoopArgs[i]); // 1. signature of the new ForOp - auto newForOp = builder.create(forOp.getLoc(), - forOp.getLowerBound(), - forOp.getUpperBound(), - forOp.getStep(), - newLoopArgs); + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newLoopArgs); // 2. body of the new ForOp builder.setInsertionPointToStart(newForOp.getBody()); @@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() { // 3. replace loads with block args (from prologue) for (size_t idx = 0; idx < loads.size(); ++idx) { Value load = loads[idx]; - assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)"); + assert(load.hasOneUse() && + "we assume that this load has one use (ConvertLayout)"); Value loadUse = load.getUsers().begin()->getResult(0); mapping.lookup(loadUse).replaceAllUsesWith( - newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]); + newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]); } - // 4. prefetch the next iteration - SmallVector orderedDeps; + SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); @@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() { DenseMap depArgsMapping; size_t argIdx = 0; for (BlockArgument arg : depArgs) { - nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]); + nextMapping.map(arg, + newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]); ++argIdx; } // special handling for iv & loop condition - Value nextIV = builder.create(newForOp.getInductionVar().getLoc(), - newForOp.getRegionIterArgs()[nextIVIdx], - newForOp.getStep()); - Value nextLoopCond = builder.create( - nextIV.getLoc(), arith::CmpIPredicate::slt, - nextIV, newForOp.getUpperBound()); + Value nextIV = builder.create( + newForOp.getInductionVar().getLoc(), + newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep()); + Value nextLoopCond = + builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, + nextIV, newForOp.getUpperBound()); for (Operation *op : orderedDeps) { Operation *nextOp = nullptr; // update loading mask if (loads.contains(op->getResult(0))) { auto loadOp = llvm::cast(op); Value mask = loadOp.mask(); - Value splatCond = builder.create(mask.getLoc(), - mask.getType(), - nextLoopCond); - Value newMask = builder.create(mask.getLoc(), - splatCond, - nextMapping.lookupOrDefault(mask)); - // if mask is defined outside the loop, don't update the map more than once + Value splatCond = builder.create( + mask.getLoc(), mask.getType(), nextLoopCond); + Value newMask = builder.create( + mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask)); + // if mask is defined outside the loop, don't update the map more than + // once if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) nextMapping.map(mask, newMask); // TODO: more elegant way to do this? nextOp = builder.create( - op->getLoc(), loadsMapping[op->getResult(0)].getType(), - nextMapping.lookupOrDefault(loadOp.ptr()), - nextMapping.lookupOrDefault(loadOp.mask()), - nextMapping.lookupOrDefault(loadOp.other()), - loadOp.cache(), loadOp.evict(), loadOp.isVolatile() - ); - } - else + op->getLoc(), loadsMapping[op->getResult(0)].getType(), + nextMapping.lookupOrDefault(loadOp.ptr()), + nextMapping.lookupOrDefault(loadOp.mask()), + nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(), + loadOp.evict(), loadOp.isVolatile()); + } else nextOp = builder.clone(*op, nextMapping); // llvm::errs() << "epilogue cloning...: " << *op << "\n"; // update mapping of results @@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() { for (size_t idx = 0; idx < loads.size(); ++idx) { Value load = loads[idx]; for (int stage = 1; stage < numStages - 1; ++stage) { - yieldValues.push_back(newForOp.getRegionIterArgs()[ - loadIdx + idx*(numStages-1) + stage - ]); + yieldValues.push_back( + newForOp + .getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]); } yieldValues.push_back(nextMapping.lookup(load)); } for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) - yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); + yieldValues.push_back( + depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); yieldValues.push_back(nextIV); builder.setInsertionPointToEnd(newForOp.getBody()); builder.create(forOp.getBody()->getTerminator()->getLoc(), @@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase { PipelinePass() = default; - PipelinePass(int numStages) { - this->numStages = numStages; - } + PipelinePass(int numStages) { this->numStages = numStages; } void runOnOperation() override { int numStages = this->numStages; diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index bedf9f38a..091ca05d3 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -1,7 +1,7 @@ #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "mlir/IR/BlockAndValueMapping.h" #include using namespace mlir; @@ -10,7 +10,7 @@ using namespace mlir::triton::gpu; // // TypeConverter // -TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, int numThreads) : context(context), numThreads(numThreads) { // TODO: how does MLIR pick the right conversion? @@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // or assert no encoding? // Now we assume: - // contiguous = 1, order = 0, 1, 2, ..., + // contiguous = 1, order = 0, 1, 2, ..., llvm::SmallVector threadTileSize(rank, 1); // naive layout llvm::SmallVector warpTileSize(rank, 1); llvm::SmallVector blockTileSize(rank); llvm::SmallVector order(rank); llvm::SmallVector broadcastAxis; int remainingThreads = numThreads; - int remainingLanes = /*warp size*/32; + int remainingLanes = /*warp size*/ 32; for (int64_t dim = 0; dim < rank; ++dim) { blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim])); warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim])); @@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // TODO: will we need repetition? } Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( - context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis); + context, threadTileSize, warpTileSize, blockTileSize, order, + broadcastAxis); return RankedTensorType::get(shape, elementType, encoding); }); @@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // // This will be called when (newArgType != origArgType) // This will create newArg, and map(origArg, newArg) - addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) { + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) { llvm_unreachable("Not implemented"); return llvm::None; }); @@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) { + ValueRange inputs, Location loc) { llvm_unreachable("Not implemented"); return llvm::None; }); @@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // where, desiredType = typeConverter->convertType(origType) // NOTE: only for remapped values. addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) { + ValueRange inputs, Location loc) { llvm_unreachable("Not implemented"); return llvm::None; }); @@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // TritonGPUConversion // TritonGPUConversionTarget::TritonGPUConversionTarget( - MLIRContext &context, TritonGPUTypeConverter &typeConverter) + MLIRContext &context, TritonGPUTypeConverter &typeConverter) : ConversionTarget(context), typeConverter(typeConverter) { // TODO: we should also verify ops of TritonGPUDialect addLegalDialect(); // Some ops from SCF are illegal - addIllegalOp(); - - addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) - return true; - return false; - }); + addIllegalOp(); + addDynamicallyLegalDialect( + [&](Operation *op) { + if (typeConverter.isLegal(op)) + return true; + return false; + }); // We have requirements for the data layouts addDynamicallyLegalOp([this](triton::DotOp dotOp) -> bool { - Attribute aEncoding = dotOp.a().getType().cast().getEncoding(); - Attribute bEncoding = dotOp.b().getType().cast().getEncoding(); - if (aEncoding && aEncoding.isa() && + Attribute aEncoding = + dotOp.a().getType().cast().getEncoding(); + Attribute bEncoding = + dotOp.b().getType().cast().getEncoding(); + if (aEncoding && + aEncoding.isa() && bEncoding && bEncoding.isa()) return true; // // TODO: we should delete this @@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( // return true; return false; }); - } // %dst = tt.broadcast %src @@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( // %bcst = tt.broadcast %newSrc // %dst = convert_layout %bcst LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod, - int numThreads) { + int numThreads) { // collect broadcasts SmallVector broadcasts; - mod.walk([&](triton::BroadcastOp op) { - broadcasts.push_back(op); - }); + mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); }); BlockAndValueMapping mapping; for (auto broadcast : broadcasts) { @@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod, broadcastAxis.push_back(ax); Attribute originSrcEnc = tensorType.getEncoding(); - if (auto blockedEnc = originSrcEnc.dyn_cast()) { + if (auto blockedEnc = + originSrcEnc.dyn_cast()) { auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get( - blockedEnc.getContext(), - blockedEnc.getThreadTileSize(), - blockedEnc.getWarpTileSize(), - blockedEnc.getBlockTileSize(), - blockedEnc.getOrder(), - broadcastAxis - ); + blockedEnc.getContext(), blockedEnc.getThreadTileSize(), + blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(), + blockedEnc.getOrder(), broadcastAxis); newSrcType = RankedTensorType::get( - tensorType.getShape(), - tensorType.getElementType(), - newSrcEnc - ); + tensorType.getShape(), tensorType.getElementType(), newSrcEnc); } else llvm_unreachable("src of broadcast should have blocked encoding"); } else { @@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod, // create new src if (!isSrcScalar) // we don't need to convert layout for scalar values - src = builder.create( - src.getLoc(), newSrcType, src - ); + src = builder.create(src.getLoc(), + newSrcType, src); // create new broadcast // compute new type (encoding) auto originDstEnc = originDstTensorType.getEncoding() - .dyn_cast(); + .dyn_cast(); auto newEnc = TritonGPUBlockedEncodingAttr::get( - originDstEnc.getContext(), - originDstEnc.getThreadTileSize(), - originDstEnc.getWarpTileSize(), - originDstEnc.getBlockTileSize(), - originDstEnc.getOrder(), - broadcastAxis - ); - auto newType = RankedTensorType::get( - originDstTensorType.getShape(), - originDstTensorType.getElementType(), - newEnc - ); - Value newBroadcast = builder.create( - broadcast.getLoc(), newType, src - ); + originDstEnc.getContext(), originDstEnc.getThreadTileSize(), + originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(), + originDstEnc.getOrder(), broadcastAxis); + auto newType = + RankedTensorType::get(originDstTensorType.getShape(), + originDstTensorType.getElementType(), newEnc); + Value newBroadcast = + builder.create(broadcast.getLoc(), newType, src); // we don't want to change the encoding of the result Value newDst = builder.create( - broadcast.getLoc(), originDstType, newBroadcast - ); + broadcast.getLoc(), originDstType, newBroadcast); broadcast.replaceAllUsesWith(newDst); mapping.map(broadcast, newDst); diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index 16e1d3ec6..e88799927 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -5,7 +5,6 @@ using namespace mlir; - #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -37,28 +36,30 @@ private: if (!encoding.isa()) return dotOp.emitError() << name << " should be of shared layout"; } else - return dotOp.emitError() << name << "'s type should be of RankedTensorType"; + return dotOp.emitError() + << name << "'s type should be of RankedTensorType"; } Attribute cLayout; for (auto it : llvm::zip(llvm::SmallVector{cType, dType}, - llvm::SmallVector{'c', 'd'})) { + llvm::SmallVector{'c', 'd'})) { Type type = std::get<0>(it); char name = std::get<1>(it); if (auto tensorType = type.dyn_cast()) { Attribute encoding = tensorType.getEncoding(); if (!encoding) return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa() && + if (!encoding.isa() && !encoding.isa()) - return dotOp.emitError() << name << " should be of distributed layout"; + return dotOp.emitError() + << name << " should be of distributed layout"; if (name == 'c') cLayout = encoding; else if (encoding != cLayout) return dotOp.emitError() << "d & c should have the same layout"; } else - return dotOp.emitError() << name - << "'s type should be of RankedTensorType"; + return dotOp.emitError() + << name << "'s type should be of RankedTensorType"; } // signalPassFailure(); @@ -89,7 +90,7 @@ private: } void verifyImpl(Operation *op) { - if(verifySingleOp(op).failed()) + if (verifySingleOp(op).failed()) signalPassFailure(); // verify that all child regions are ok diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc old mode 100755 new mode 100644 index 9e2aca432..427453b38 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -1,107 +1,152 @@ /* Copyright 2015-2017 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, -* subject to the following conditions: -* -* The above copyright notice and this permission notice shall be -* included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ #include "triton/driver/dispatch.h" -namespace triton -{ -namespace driver -{ +namespace triton { +namespace driver { -//Helpers for function definition -#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\ -{return f_impl(hlib, fname, fname ## _, #fname); }\ -void* dispatch::fname ## _; +// Helpers for function definition +#define DEFINE0(init, hlib, ret, fname) \ + ret dispatch::fname() { \ + return f_impl(hlib, fname, fname##_, #fname); \ + } \ + void *dispatch::fname##_; -#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\ -{return f_impl(hlib, fname, fname ## _, #fname, a); }\ -void* dispatch::fname ## _; +#define DEFINE1(init, hlib, ret, fname, t1) \ + ret dispatch::fname(t1 a) { \ + return f_impl(hlib, fname, fname##_, #fname, a); \ + } \ + void *dispatch::fname##_; -#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b); }\ -void* dispatch::fname ## _; +#define DEFINE2(init, hlib, ret, fname, t1, t2) \ + ret dispatch::fname(t1 a, t2 b) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b); \ + } \ + void *dispatch::fname##_; -#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c); }\ -void* dispatch::fname ## _; +#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \ + ret dispatch::fname(t1 a, t2 b, t3 c) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c); \ + } \ + void *dispatch::fname##_; -#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d); }\ -void* dispatch::fname ## _; +#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d); \ + } \ + void *dispatch::fname##_; -#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\ -void* dispatch::fname ## _; +#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e); \ + } \ + void *dispatch::fname##_; -#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\ -void* dispatch::fname ## _; +#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f); \ + } \ + void *dispatch::fname##_; -#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\ -void* dispatch::fname ## _; +#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g); \ + } \ + void *dispatch::fname##_; -#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\ -void* dispatch::fname ## _; +#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h); \ + } \ + void *dispatch::fname##_; -#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\ -void* dispatch::fname ## _; +#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h, i); \ + } \ + void *dispatch::fname##_; -#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\ -void* dispatch::fname ## _; +#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \ + t10) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \ + t10 j) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h, i, j); \ + } \ + void *dispatch::fname##_; -#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\ -void* dispatch::fname ## _; +#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \ + t10, t11) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \ + t10 j, t11 k) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h, i, j, k); \ + } \ + void *dispatch::fname##_; -#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\ -void* dispatch::fname ## _; - -#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\ -{return f_impl(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\ -void* dispatch::fname ## _; +#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \ + t10, t11, t12, t13) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \ + t10 j, t11 k, t12 l, t13 m) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h, i, j, k, l, m); \ + } \ + void *dispatch::fname##_; +#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \ + t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \ + ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \ + t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \ + t18 r, t19 s) { \ + return f_impl(hlib, fname, fname##_, #fname, a, b, c, d, \ + e, f, g, h, i, j, k, l, m, n, o, p, q, r, \ + s); \ + } \ + void *dispatch::fname##_; /* ------------------- * * CUDA * ------------------- */ -bool dispatch::cuinit(){ - if(cuda_==nullptr){ - #ifdef _WIN32 +bool dispatch::cuinit() { + if (cuda_ == nullptr) { +#ifdef _WIN32 cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY); - #else +#else cuda_ = dlopen("libcuda.so", RTLD_LAZY); - if(!cuda_) + if (!cuda_) cuda_ = dlopen("libcuda.so.1", RTLD_LAZY); - #endif - if(!cuda_) - throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH."); +#endif + if (!cuda_) + throw std::runtime_error("Could not find `libcuda.so`. Make sure it is " + "in your LD_LIBRARY_PATH."); } - if(cuda_ == nullptr) + if (cuda_ == nullptr) return false; CUresult (*fptr)(unsigned int); cuInit_ = dlsym(cuda_, "cuInit"); @@ -112,21 +157,33 @@ bool dispatch::cuinit(){ } #define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1) -#define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2) -#define CUDA_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3) -#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4) -#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5) -#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6) -#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7) -#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) -#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) -#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) -#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) +#define CUDA_DEFINE2(ret, fname, t1, t2) \ + DEFINE2(cuinit, cuda_, ret, fname, t1, t2) +#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \ + DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3) +#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \ + DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4) +#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \ + DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5) +#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \ + DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6) +#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \ + DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7) +#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \ + DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) +#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) +#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \ + DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) +#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \ + t11) \ + DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \ + t11) // context management CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext) CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice) -CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*) +CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *) CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int) CUDA_DEFINE1(CUresult, cuInit, unsigned int) CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *) @@ -134,59 +191,71 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *) CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int) CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice) CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice) -CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice) -CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*) +CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, + CUdevice) +CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *) // link management -CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); -CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*); +CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *, + size_t, const char *, unsigned int, CUjit_option *, void **); +CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **, + CUlinkState *); CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState); -CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*); +CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *); // module management -CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*) +CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule, + const char *) CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *) CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule) CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *) -CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **) -CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *) +CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, + unsigned int, CUjit_option *, void **) +CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, + const char *) // stream management CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int) CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream) CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream) -CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*) -CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **) +CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *) +CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, CUstream, void **, void **) // function management -CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction) -CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int) +CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute, + CUfunction) +CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, + int) CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache) // memory management CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t) CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr) -CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream) -CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream) -CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t ) -CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t) -CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr) -CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream) +CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, + CUstream) +CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, + CUstream) +CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t) +CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t) +CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute, + CUdeviceptr) +CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, + CUstream) // event management CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int) CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent) CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream) CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent) - - /* ------------------- * * NVML * ------------------- */ -bool dispatch::nvmlinit(){ - #ifdef _WIN32 - if(nvml_==nullptr) +bool dispatch::nvmlinit() { +#ifdef _WIN32 + if (nvml_ == nullptr) nvml_ = dlopen("nvml.dll", RTLD_LAZY); - #else - if(nvml_==nullptr) +#else + if (nvml_ == nullptr) nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY); - #endif +#endif nvmlReturn_t (*fptr)(); nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2"); *reinterpret_cast(&fptr) = nvmlInit_v2_; @@ -197,21 +266,27 @@ bool dispatch::nvmlinit(){ #define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname) #define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1) -#define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2) -#define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3) +#define NVML_DEFINE2(ret, fname, t1, t2) \ + DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2) +#define NVML_DEFINE3(ret, fname, t1, t2, t3) \ + DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3) -NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*) -NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*) -NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*) -NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int) +NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, + nvmlDevice_t *) +NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, + nvmlClockType_t, unsigned int *) +NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, + nvmlClockType_t, unsigned int *) +NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, + unsigned int, unsigned int) /* ------------------- * * HIP * ------------------- */ -bool dispatch::hipinit(){ - if(hip_==nullptr) +bool dispatch::hipinit() { + if (hip_ == nullptr) hip_ = dlopen("libamdhip64.so", RTLD_LAZY); - if(hip_ == nullptr) + if (hip_ == nullptr) return false; hipError_t (*fptr)(); hipInit_ = dlsym(hip_, "hipInit"); @@ -222,23 +297,34 @@ bool dispatch::hipinit(){ } #define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1) -#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2) -#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3) -#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4) -#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5) -#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6) -#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7) -#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) -#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) -#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) -#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) +#define HIP_DEFINE2(ret, fname, t1, t2) \ + DEFINE2(hipinit, hip_, ret, fname, t1, t2) +#define HIP_DEFINE3(ret, fname, t1, t2, t3) \ + DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3) +#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \ + DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4) +#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \ + DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5) +#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \ + DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6) +#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \ + DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7) +#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \ + DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) +#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) +#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \ + DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) +#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \ + DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \ + t11) // context management HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t) HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t) -HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*) +HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *) HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t) -HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*) +HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *) HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int) HIP_DEFINE1(hipError_t, hipInit, unsigned int) HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *) @@ -246,56 +332,64 @@ HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *) HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int) HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t) HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t) -HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t) +HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, + hipDevice_t) HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *) // module management -HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*) +HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *, + hipModule_t, const char *) HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *) HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t) HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *) -HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **) -HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *) +HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, + unsigned int, hipJitOption *, void **) +HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, + const char *) // stream management HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int) HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t) HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t) -HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **) +HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, + unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, hipStream_t, void **, void **) // function management -HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*) +HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *) HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t) // memory management HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t) HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t) -HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t) -HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t) -HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t ) -HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t) -HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t) -HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t) +HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, + hipStream_t) +HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, + size_t, hipStream_t) +HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t) +HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t) +HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute, + hipDeviceptr_t) +HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, + hipStream_t) // event management HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int) HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t) HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t) HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t) - /* ------------------- * * COMMON * ------------------- */ // Release -void dispatch::release(){ - if(cuda_){ +void dispatch::release() { + if (cuda_) { dlclose(cuda_); cuda_ = nullptr; } } -void* dispatch::cuda_; -void* dispatch::nvml_; -void* dispatch::nvmlInit_v2_; -void* dispatch::hip_; +void *dispatch::cuda_; +void *dispatch::nvml_; +void *dispatch::nvmlInit_v2_; +void *dispatch::hip_; - -} -} +} // namespace driver +} // namespace triton diff --git a/lib/driver/error.cc b/lib/driver/error.cc old mode 100755 new mode 100644 index f723351c2..4b366746e --- a/lib/driver/error.cc +++ b/lib/driver/error.cc @@ -1,166 +1,270 @@ /* Copyright 2015-2017 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, -* subject to the following conditions: -* -* The above copyright notice and this permission notice shall be -* included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ #include "triton/driver/error.h" -namespace triton -{ -namespace driver -{ +namespace triton { +namespace driver { -void check(CUresult err) -{ +void check(CUresult err) { using namespace exception::cuda; - switch(err) - { - case CUDA_SUCCESS : break; - case CUDA_ERROR_INVALID_VALUE : throw invalid_value(); - case CUDA_ERROR_OUT_OF_MEMORY : throw out_of_memory(); - case CUDA_ERROR_NOT_INITIALIZED : throw not_initialized(); - case CUDA_ERROR_DEINITIALIZED : throw deinitialized(); - case CUDA_ERROR_PROFILER_DISABLED : throw profiler_disabled(); - case CUDA_ERROR_PROFILER_NOT_INITIALIZED : throw profiler_not_initialized(); - case CUDA_ERROR_PROFILER_ALREADY_STARTED : throw profiler_already_started(); - case CUDA_ERROR_PROFILER_ALREADY_STOPPED : throw profiler_already_stopped(); - case CUDA_ERROR_NO_DEVICE : throw no_device(); - case CUDA_ERROR_INVALID_DEVICE : throw invalid_device(); - case CUDA_ERROR_INVALID_IMAGE : throw invalid_image(); - case CUDA_ERROR_INVALID_CONTEXT : throw invalid_context(); - case CUDA_ERROR_CONTEXT_ALREADY_CURRENT : throw context_already_current(); - case CUDA_ERROR_MAP_FAILED : throw map_failed(); - case CUDA_ERROR_UNMAP_FAILED : throw unmap_failed(); - case CUDA_ERROR_ARRAY_IS_MAPPED : throw array_is_mapped(); - case CUDA_ERROR_ALREADY_MAPPED : throw already_mapped(); - case CUDA_ERROR_NO_BINARY_FOR_GPU : throw no_binary_for_gpu(); - case CUDA_ERROR_ALREADY_ACQUIRED : throw already_acquired(); - case CUDA_ERROR_NOT_MAPPED : throw not_mapped(); - case CUDA_ERROR_NOT_MAPPED_AS_ARRAY : throw not_mapped_as_array(); - case CUDA_ERROR_NOT_MAPPED_AS_POINTER : throw not_mapped_as_pointer(); - case CUDA_ERROR_ECC_UNCORRECTABLE : throw ecc_uncorrectable(); - case CUDA_ERROR_UNSUPPORTED_LIMIT : throw unsupported_limit(); - case CUDA_ERROR_CONTEXT_ALREADY_IN_USE : throw context_already_in_use(); - case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED : throw peer_access_unsupported(); - case CUDA_ERROR_INVALID_PTX : throw invalid_ptx(); - case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT : throw invalid_graphics_context(); - case CUDA_ERROR_INVALID_SOURCE : throw invalid_source(); - case CUDA_ERROR_FILE_NOT_FOUND : throw file_not_found(); - case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND : throw shared_object_symbol_not_found(); - case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED : throw shared_object_init_failed(); - case CUDA_ERROR_OPERATING_SYSTEM : throw operating_system(); - case CUDA_ERROR_INVALID_HANDLE : throw invalid_handle(); - case CUDA_ERROR_NOT_FOUND : throw not_found(); - case CUDA_ERROR_NOT_READY : throw not_ready(); - case CUDA_ERROR_ILLEGAL_ADDRESS : throw illegal_address(); - case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES : throw launch_out_of_resources(); - case CUDA_ERROR_LAUNCH_TIMEOUT : throw launch_timeout(); - case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING : throw launch_incompatible_texturing(); - case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED : throw peer_access_already_enabled(); - case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED : throw peer_access_not_enabled(); - case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE : throw primary_context_active(); - case CUDA_ERROR_CONTEXT_IS_DESTROYED : throw context_is_destroyed(); - case CUDA_ERROR_ASSERT : throw assert_error(); - case CUDA_ERROR_TOO_MANY_PEERS : throw too_many_peers(); - case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED : throw host_memory_already_registered(); - case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED : throw host_memory_not_registered(); - case CUDA_ERROR_HARDWARE_STACK_ERROR : throw hardware_stack_error(); - case CUDA_ERROR_ILLEGAL_INSTRUCTION : throw illegal_instruction(); - case CUDA_ERROR_MISALIGNED_ADDRESS : throw misaligned_address(); - case CUDA_ERROR_INVALID_ADDRESS_SPACE : throw invalid_address_space(); - case CUDA_ERROR_INVALID_PC : throw invalid_pc(); - case CUDA_ERROR_LAUNCH_FAILED : throw launch_failed(); - case CUDA_ERROR_NOT_PERMITTED : throw not_permitted(); - case CUDA_ERROR_NOT_SUPPORTED : throw not_supported(); - case CUDA_ERROR_UNKNOWN : throw unknown(); - default : throw unknown(); + switch (err) { + case CUDA_SUCCESS: + break; + case CUDA_ERROR_INVALID_VALUE: + throw invalid_value(); + case CUDA_ERROR_OUT_OF_MEMORY: + throw out_of_memory(); + case CUDA_ERROR_NOT_INITIALIZED: + throw not_initialized(); + case CUDA_ERROR_DEINITIALIZED: + throw deinitialized(); + case CUDA_ERROR_PROFILER_DISABLED: + throw profiler_disabled(); + case CUDA_ERROR_PROFILER_NOT_INITIALIZED: + throw profiler_not_initialized(); + case CUDA_ERROR_PROFILER_ALREADY_STARTED: + throw profiler_already_started(); + case CUDA_ERROR_PROFILER_ALREADY_STOPPED: + throw profiler_already_stopped(); + case CUDA_ERROR_NO_DEVICE: + throw no_device(); + case CUDA_ERROR_INVALID_DEVICE: + throw invalid_device(); + case CUDA_ERROR_INVALID_IMAGE: + throw invalid_image(); + case CUDA_ERROR_INVALID_CONTEXT: + throw invalid_context(); + case CUDA_ERROR_CONTEXT_ALREADY_CURRENT: + throw context_already_current(); + case CUDA_ERROR_MAP_FAILED: + throw map_failed(); + case CUDA_ERROR_UNMAP_FAILED: + throw unmap_failed(); + case CUDA_ERROR_ARRAY_IS_MAPPED: + throw array_is_mapped(); + case CUDA_ERROR_ALREADY_MAPPED: + throw already_mapped(); + case CUDA_ERROR_NO_BINARY_FOR_GPU: + throw no_binary_for_gpu(); + case CUDA_ERROR_ALREADY_ACQUIRED: + throw already_acquired(); + case CUDA_ERROR_NOT_MAPPED: + throw not_mapped(); + case CUDA_ERROR_NOT_MAPPED_AS_ARRAY: + throw not_mapped_as_array(); + case CUDA_ERROR_NOT_MAPPED_AS_POINTER: + throw not_mapped_as_pointer(); + case CUDA_ERROR_ECC_UNCORRECTABLE: + throw ecc_uncorrectable(); + case CUDA_ERROR_UNSUPPORTED_LIMIT: + throw unsupported_limit(); + case CUDA_ERROR_CONTEXT_ALREADY_IN_USE: + throw context_already_in_use(); + case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED: + throw peer_access_unsupported(); + case CUDA_ERROR_INVALID_PTX: + throw invalid_ptx(); + case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT: + throw invalid_graphics_context(); + case CUDA_ERROR_INVALID_SOURCE: + throw invalid_source(); + case CUDA_ERROR_FILE_NOT_FOUND: + throw file_not_found(); + case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND: + throw shared_object_symbol_not_found(); + case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED: + throw shared_object_init_failed(); + case CUDA_ERROR_OPERATING_SYSTEM: + throw operating_system(); + case CUDA_ERROR_INVALID_HANDLE: + throw invalid_handle(); + case CUDA_ERROR_NOT_FOUND: + throw not_found(); + case CUDA_ERROR_NOT_READY: + throw not_ready(); + case CUDA_ERROR_ILLEGAL_ADDRESS: + throw illegal_address(); + case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: + throw launch_out_of_resources(); + case CUDA_ERROR_LAUNCH_TIMEOUT: + throw launch_timeout(); + case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING: + throw launch_incompatible_texturing(); + case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED: + throw peer_access_already_enabled(); + case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED: + throw peer_access_not_enabled(); + case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE: + throw primary_context_active(); + case CUDA_ERROR_CONTEXT_IS_DESTROYED: + throw context_is_destroyed(); + case CUDA_ERROR_ASSERT: + throw assert_error(); + case CUDA_ERROR_TOO_MANY_PEERS: + throw too_many_peers(); + case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED: + throw host_memory_already_registered(); + case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED: + throw host_memory_not_registered(); + case CUDA_ERROR_HARDWARE_STACK_ERROR: + throw hardware_stack_error(); + case CUDA_ERROR_ILLEGAL_INSTRUCTION: + throw illegal_instruction(); + case CUDA_ERROR_MISALIGNED_ADDRESS: + throw misaligned_address(); + case CUDA_ERROR_INVALID_ADDRESS_SPACE: + throw invalid_address_space(); + case CUDA_ERROR_INVALID_PC: + throw invalid_pc(); + case CUDA_ERROR_LAUNCH_FAILED: + throw launch_failed(); + case CUDA_ERROR_NOT_PERMITTED: + throw not_permitted(); + case CUDA_ERROR_NOT_SUPPORTED: + throw not_supported(); + case CUDA_ERROR_UNKNOWN: + throw unknown(); + default: + throw unknown(); } } void check(hipError_t error) { using namespace exception::hip; - switch(error) - { - case hipSuccess : break; - case hipErrorInvalidValue : throw invalid_value(); - case hipErrorMemoryAllocation : throw out_of_memory(); - case hipErrorNotInitialized : throw not_initialized(); - case hipErrorDeinitialized : throw deinitialized(); - case hipErrorProfilerDisabled : throw profiler_disabled(); - case hipErrorProfilerNotInitialized : throw profiler_not_initialized(); - case hipErrorProfilerAlreadyStarted : throw profiler_already_started(); - case hipErrorProfilerAlreadyStopped : throw profiler_already_stopped(); - case hipErrorNoDevice : throw no_device(); - case hipErrorInvalidSymbol : throw invalid_symbol(); - case hipErrorInvalidDevice : throw invalid_device(); - case hipErrorInvalidImage : throw invalid_image(); - case hipErrorInvalidContext : throw invalid_context(); - case hipErrorContextAlreadyCurrent : throw context_already_current(); - case hipErrorMapFailed : throw map_failed(); - case hipErrorUnmapFailed : throw unmap_failed(); - case hipErrorArrayIsMapped : throw array_is_mapped(); - case hipErrorAlreadyMapped : throw already_mapped(); - case hipErrorNoBinaryForGpu : throw no_binary_for_gpu(); - case hipErrorAlreadyAcquired : throw already_acquired(); - case hipErrorNotMapped : throw not_mapped(); - case hipErrorNotMappedAsArray : throw not_mapped_as_array(); - case hipErrorNotMappedAsPointer : throw not_mapped_as_pointer(); - case hipErrorECCNotCorrectable : throw ecc_uncorrectable(); - case hipErrorUnsupportedLimit : throw unsupported_limit(); - case hipErrorContextAlreadyInUse : throw context_already_in_use(); - case hipErrorPeerAccessUnsupported : throw peer_access_unsupported(); - case hipErrorInvalidKernelFile : throw invalid_ptx(); - case hipErrorInvalidGraphicsContext : throw invalid_graphics_context(); - case hipErrorInvalidSource : throw invalid_source(); - case hipErrorFileNotFound : throw file_not_found(); - case hipErrorSharedObjectSymbolNotFound : throw shared_object_symbol_not_found(); - case hipErrorSharedObjectInitFailed : throw shared_object_init_failed(); - case hipErrorOperatingSystem : throw operating_system(); - case hipErrorInvalidResourceHandle : throw invalid_handle(); - case hipErrorNotFound : throw not_found(); - case hipErrorNotReady : throw not_ready(); - case hipErrorIllegalAddress : throw illegal_address(); - case hipErrorLaunchOutOfResources : throw launch_out_of_resources(); - case hipErrorLaunchTimeOut : throw launch_timeout(); - // case hipErrorLaunchIncompatibleTexturing : throw launch_incompatible_texturing(); - case hipErrorPeerAccessAlreadyEnabled : throw peer_access_already_enabled(); - case hipErrorPeerAccessNotEnabled : throw peer_access_not_enabled(); - // case hipErrorPrimaryContextActive : throw primary_context_active(); - // case hipErrorContextIsDestroyed : throw context_is_destroyed(); - case hipErrorAssert : throw assert_error(); - // case hipErrorTooManyPeers : throw too_many_peers(); - case hipErrorHostMemoryAlreadyRegistered : throw host_memory_already_registered(); - case hipErrorHostMemoryNotRegistered : throw host_memory_not_registered(); - // case hipErrorHardwareStackError : throw hardware_stack_error(); - // case hipErrorIllegalInstruction : throw illegal_instruction(); - // case hipErrorMisalignedAddress : throw misaligned_address(); - // case hipErrorInvalidAddressSpace : throw invalid_address_space(); - // case hipErrorInvalidPc : throw invalid_pc(); - case hipErrorLaunchFailure : throw launch_failed(); - // case hipErrorNotPermitted : throw not_permitted(); - case hipErrorNotSupported : throw not_supported(); - case hipErrorUnknown : throw unknown(); - default : throw unknown(); -} -} - -} + switch (error) { + case hipSuccess: + break; + case hipErrorInvalidValue: + throw invalid_value(); + case hipErrorMemoryAllocation: + throw out_of_memory(); + case hipErrorNotInitialized: + throw not_initialized(); + case hipErrorDeinitialized: + throw deinitialized(); + case hipErrorProfilerDisabled: + throw profiler_disabled(); + case hipErrorProfilerNotInitialized: + throw profiler_not_initialized(); + case hipErrorProfilerAlreadyStarted: + throw profiler_already_started(); + case hipErrorProfilerAlreadyStopped: + throw profiler_already_stopped(); + case hipErrorNoDevice: + throw no_device(); + case hipErrorInvalidSymbol: + throw invalid_symbol(); + case hipErrorInvalidDevice: + throw invalid_device(); + case hipErrorInvalidImage: + throw invalid_image(); + case hipErrorInvalidContext: + throw invalid_context(); + case hipErrorContextAlreadyCurrent: + throw context_already_current(); + case hipErrorMapFailed: + throw map_failed(); + case hipErrorUnmapFailed: + throw unmap_failed(); + case hipErrorArrayIsMapped: + throw array_is_mapped(); + case hipErrorAlreadyMapped: + throw already_mapped(); + case hipErrorNoBinaryForGpu: + throw no_binary_for_gpu(); + case hipErrorAlreadyAcquired: + throw already_acquired(); + case hipErrorNotMapped: + throw not_mapped(); + case hipErrorNotMappedAsArray: + throw not_mapped_as_array(); + case hipErrorNotMappedAsPointer: + throw not_mapped_as_pointer(); + case hipErrorECCNotCorrectable: + throw ecc_uncorrectable(); + case hipErrorUnsupportedLimit: + throw unsupported_limit(); + case hipErrorContextAlreadyInUse: + throw context_already_in_use(); + case hipErrorPeerAccessUnsupported: + throw peer_access_unsupported(); + case hipErrorInvalidKernelFile: + throw invalid_ptx(); + case hipErrorInvalidGraphicsContext: + throw invalid_graphics_context(); + case hipErrorInvalidSource: + throw invalid_source(); + case hipErrorFileNotFound: + throw file_not_found(); + case hipErrorSharedObjectSymbolNotFound: + throw shared_object_symbol_not_found(); + case hipErrorSharedObjectInitFailed: + throw shared_object_init_failed(); + case hipErrorOperatingSystem: + throw operating_system(); + case hipErrorInvalidResourceHandle: + throw invalid_handle(); + case hipErrorNotFound: + throw not_found(); + case hipErrorNotReady: + throw not_ready(); + case hipErrorIllegalAddress: + throw illegal_address(); + case hipErrorLaunchOutOfResources: + throw launch_out_of_resources(); + case hipErrorLaunchTimeOut: + throw launch_timeout(); + // case hipErrorLaunchIncompatibleTexturing : throw + // launch_incompatible_texturing(); + case hipErrorPeerAccessAlreadyEnabled: + throw peer_access_already_enabled(); + case hipErrorPeerAccessNotEnabled: + throw peer_access_not_enabled(); + // case hipErrorPrimaryContextActive : throw primary_context_active(); + // case hipErrorContextIsDestroyed : throw context_is_destroyed(); + case hipErrorAssert: + throw assert_error(); + // case hipErrorTooManyPeers : throw too_many_peers(); + case hipErrorHostMemoryAlreadyRegistered: + throw host_memory_already_registered(); + case hipErrorHostMemoryNotRegistered: + throw host_memory_not_registered(); + // case hipErrorHardwareStackError : throw hardware_stack_error(); + // case hipErrorIllegalInstruction : throw illegal_instruction(); + // case hipErrorMisalignedAddress : throw misaligned_address(); + // case hipErrorInvalidAddressSpace : throw invalid_address_space(); + // case hipErrorInvalidPc : throw invalid_pc(); + case hipErrorLaunchFailure: + throw launch_failed(); + // case hipErrorNotPermitted : throw not_permitted(); + case hipErrorNotSupported: + throw not_supported(); + case hipErrorUnknown: + throw unknown(); + default: + throw unknown(); + } } +} // namespace driver +} // namespace triton diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index ee82c467e..f78e9c8e9 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -1,73 +1,73 @@ /* Copyright 2015-2017 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, -* subject to the following conditions: -* -* The above copyright notice and this permission notice shall be -* included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ #include #if __has_include() - #include +#include #endif -#include -#include -#include "triton/driver/llvm.h" #include "triton/driver/dispatch.h" #include "triton/driver/error.h" +#include "triton/driver/llvm.h" #include "triton/tools/sha1.hpp" +#include "triton/tools/sys/exec.hpp" #include "triton/tools/sys/getenv.hpp" #include "triton/tools/sys/mkdir.hpp" -#include "triton/tools/sys/exec.hpp" -#include "llvm/MC/TargetRegistry.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Verifier.h" #include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SourceMgr.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include +#include // begin AMD stuff +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Support/Program.h" #include "llvm/Support/ToolOutputFile.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/TargetLibraryInfo.h" // end AMD stuff -extern "C"{ - int set_curterm(char* nterm){ return 0; } - int del_curterm(char* nterm){ return 0; } - int tigetnum(char *capname) { return 0; } - int setupterm(char *term, int fildes, int *errret) { return 0; } +extern "C" { +int set_curterm(char *nterm) { return 0; } +int del_curterm(char *nterm) { return 0; } +int tigetnum(char *capname) { return 0; } +int setupterm(char *term, int fildes, int *errret) { return 0; } } -namespace triton{ -namespace driver{ +namespace triton { +namespace driver { void init_llvm() { LLVMInitializeNVPTXTargetInfo(); @@ -80,82 +80,93 @@ void init_llvm() { LLVMInitializeAMDGPUAsmPrinter(); } - /* ------------------------ */ // CUDA // /* ------------------------ */ -static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){ +static bool find_and_replace(std::string &str, const std::string &begin, + const std::string &end, + const std::string &target) { size_t start_replace = str.find(begin); size_t end_replace = str.find(end, start_replace); - if(start_replace == std::string::npos) + if (start_replace == std::string::npos) return false; str.replace(start_replace, end_replace + 1 - start_replace, target); return true; } -std::string path_to_ptxas(int& version) { +std::string path_to_ptxas(int &version) { std::vector rets; std::string ret; // search pathes for ptxas std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); - if(!triton_ptxas.empty()) + if (!triton_ptxas.empty()) ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); // see what path for ptxas are valid std::vector working_ptxas; - for(std::string prefix: ptxas_prefixes){ + for (std::string prefix : ptxas_prefixes) { std::string ptxas = prefix + "ptxas"; bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; - if(works) { + if (works) { working_ptxas.push_back(ptxas); rets.push_back(ret); } } // error if no working ptxas was found - if(working_ptxas.empty()) - throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" + if (working_ptxas.empty()) + throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, " + "/usr/local/cuda/bin/ or PATH" " but a working version could not be found."); std::string ptxas = working_ptxas.front(); // parse version std::regex version_regex("release (\\d+)\\.(\\d+)"); std::smatch match; bool found = false; - // currently choosing the first ptxas. Other logics can be implemented in future - for(std::string ret : rets) { - if(std::regex_search(ret, match, version_regex)){ + // currently choosing the first ptxas. Other logics can be implemented in + // future + for (std::string ret : rets) { + if (std::regex_search(ret, match, version_regex)) { int major = std::stoi(match[1]); int minor = std::stoi(match[2]); - version = major*1000 + minor*10; + version = major * 1000 + minor * 10; found = true; break; } } - if ( not found) { + if (not found) { throw std::runtime_error("Error in parsing version"); } return ptxas; } - -int vptx(int version){ - if(version >= 11040) return 74; - if(version >= 11030) return 73; - if(version >= 11020) return 72; - if(version >= 11010) return 71; - if(version >= 11000) return 70; - if(version >= 10020) return 65; - if(version >= 10010) return 64; - if(version >= 10000) return 63; +int vptx(int version) { + if (version >= 11040) + return 74; + if (version >= 11030) + return 73; + if (version >= 11020) + return 72; + if (version >= 11010) + return 71; + if (version >= 11000) + return 70; + if (version >= 10020) + return 65; + if (version >= 10010) + return 64; + if (version >= 10000) + return 63; throw std::runtime_error("Triton requires CUDA 10+"); } -std::string llir_to_ptx(llvm::Module* module, int cc, int version){ +std::string llir_to_ptx(llvm::Module *module, int cc, int version) { // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; int max_nvvm_ptx = 74; // options auto options = llvm::cl::getRegisteredOptions(); - auto* short_ptr = static_cast*>(options["nvptx-short-ptr"]); + auto *short_ptr = + static_cast *>(options["nvptx-short-ptr"]); assert(short_ptr); short_ptr->setValue(true); // compute capability @@ -170,7 +181,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc)); std::string layout = ""; std::string features = ""; - // std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); + // std::string features = "+ptx" + std::to_string(std::min(ptx, + // max_nvvm_ptx)); init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; @@ -181,16 +193,18 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // create machine module->setTargetTriple(triple); std::string error; - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + auto target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); llvm::TargetOptions opt; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, - llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); + llvm::TargetMachine *machine = target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + llvm::None, llvm::CodeGenOpt::Aggressive); // set data layout - if(layout.empty()) + if (layout.empty()) module->setDataLayout(machine->createDataLayout()); else module->setDataLayout(layout); @@ -200,19 +214,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ llvm::legacy::PassManager pass; llvm::raw_svector_ostream stream(buffer); // emit - machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); + machine->addPassesToEmitFile(pass, stream, nullptr, + llvm::CodeGenFileType::CGFT_AssemblyFile); pass.run(*module); // post-process std::string result(buffer.begin(), buffer.end()); - find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); + find_and_replace(result, ".version", "\n", + ".version " + std::to_string(ptx_major) + "." + + std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); - while(find_and_replace(result, "\t// begin inline asm", "\n", "")); - while(find_and_replace(result, "\t// end inline asm", "\n", "")); + while (find_and_replace(result, "\t// begin inline asm", "\n", "")) + ; + while (find_and_replace(result, "\t// end inline asm", "\n", "")) + ; return result; } -std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { +std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, + int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; char _flog[L_tmpnam]; @@ -221,15 +241,16 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c std::string fsrc = _fsrc; std::string flog = _flog; std::string fbin = fsrc + ".o"; - const char* _fbin = fbin.c_str(); + const char *_fbin = fbin.c_str(); std::ofstream ofs(fsrc); ofs << ptx << std::endl; ofs.close(); std::string cmd; int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; + cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); - if(err != 0){ + if (err != 0) { std::ifstream _log(_flog); std::string log(std::istreambuf_iterator(_log), {}); unlink(_fsrc); @@ -237,7 +258,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); } CUmodule ret; - std::ifstream _cubin(_fbin, std::ios::binary ); + std::ifstream _cubin(_fbin, std::ios::binary); std::string cubin(std::istreambuf_iterator(_cubin), {}); _cubin.close(); unlink(_fsrc); @@ -251,11 +272,11 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c // HIP // /* ------------------------ */ -std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { +std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) { init_llvm(); -// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo)); -// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo)); + // proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo)); + // features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo)); // create llvm::SmallVector buffer; @@ -270,17 +291,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { // create machine module->setTargetTriple(triple); std::string error; - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + auto target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); llvm::TargetOptions opt; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, - llvm::Reloc::PIC_, llvm::None, - llvm::CodeGenOpt::Aggressive); + llvm::TargetMachine *machine = target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + llvm::None, llvm::CodeGenOpt::Aggressive); // set data layout - if(layout.empty()) + if (layout.empty()) module->setDataLayout(machine->createDataLayout()); else module->setDataLayout(layout); @@ -295,33 +317,37 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { std::error_code ec; // Save GCN ISA binary. - std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o"); + std::string isabin_path = + std::string("/tmp/") + module_name + std::string(".o"); std::unique_ptr isabin_fs( new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text)); - if (ec) - { - std::cout << isabin_path << " was not created. error code: " << ec << std::endl; + if (ec) { + std::cout << isabin_path << " was not created. error code: " << ec + << std::endl; } // emit - machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile); + machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, + llvm::CGFT_ObjectFile); pass.run(*module); // Save GCN ISA. - std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn"); + std::string amdgcn_path = + std::string("/tmp/") + module_name + std::string(".gcn"); std::string result(buffer.begin(), buffer.end()); std::ofstream amdgcn(amdgcn_path); amdgcn << result; amdgcn.close(); // generate HASCO file - std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco"); + std::string hsaco_path = + std::string("/tmp/") + module_name + std::string(".hsaco"); std::string error_message; int lld_result = llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld", - {"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path}, + {"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", + "-shared", "-o", hsaco_path, isabin_path}, llvm::None, {}, 0, 0, &error_message); - if (lld_result) - { + if (lld_result) { std::cout << "ld.lld execute fail: " << std::endl; std::cout << error_message << std::endl; std::cout << lld_result << std::endl; @@ -330,33 +356,29 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { return hsaco_path; } - -hipModule_t amdgpu_to_hipmodule(const std::string& path) { +hipModule_t amdgpu_to_hipmodule(const std::string &path) { // Read HSACO. std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate); std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); std::vector hsaco(hsaco_file_size); hsaco_file.seekg(0, std::ios::beg); - hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); + hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); hsaco_file.close(); - hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer, - hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer, - hipJitOptionLogVerbose}; + hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, + hipJitOptionErrorLogBuffer, + hipJitOptionInfoLogBufferSizeBytes, + hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose}; const unsigned int errbufsize = 8192; const unsigned int logbufsize = 8192; char _err[errbufsize]; char _log[logbufsize]; - void* optval[] = {(void*)(uintptr_t)errbufsize, - (void*)_err, (void*)(uintptr_t)logbufsize, - (void*)_log, (void*)1}; + void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err, + (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1}; hipModule_t ret; dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval); return ret; } - - -} -} - +} // namespace driver +} // namespace triton diff --git a/python/src/pybind11/attr.h b/python/src/pybind11/attr.h index 6962d6fc5..eada4e1f1 100644 --- a/python/src/pybind11/attr.h +++ b/python/src/pybind11/attr.h @@ -18,60 +18,83 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE) /// @{ /// Annotation for methods -struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; +struct is_method { + handle class_; + is_method(const handle &c) : class_(c) {} +}; /// Annotation for operators -struct is_operator { }; +struct is_operator {}; /// Annotation for parent scope -struct scope { handle value; scope(const handle &s) : value(s) { } }; +struct scope { + handle value; + scope(const handle &s) : value(s) {} +}; /// Annotation for documentation -struct doc { const char *value; doc(const char *value) : value(value) { } }; +struct doc { + const char *value; + doc(const char *value) : value(value) {} +}; /// Annotation for function names -struct name { const char *value; name(const char *value) : value(value) { } }; +struct name { + const char *value; + name(const char *value) : value(value) {} +}; -/// Annotation indicating that a function is an overload associated with a given "sibling" -struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; +/// Annotation indicating that a function is an overload associated with a given +/// "sibling" +struct sibling { + handle value; + sibling(const handle &value) : value(value.ptr()) {} +}; /// Annotation indicating that a class derives from another given type template struct base { - PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") - base() { } + PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as " + "a template argument to class_") + base() {} }; /// Keep patient alive while nurse lives -template struct keep_alive { }; +template struct keep_alive {}; -/// Annotation indicating that a class is involved in a multiple inheritance relationship -struct multiple_inheritance { }; +/// Annotation indicating that a class is involved in a multiple inheritance +/// relationship +struct multiple_inheritance {}; /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class -struct dynamic_attr { }; +struct dynamic_attr {}; /// Annotation which enables the buffer protocol for a type -struct buffer_protocol { }; +struct buffer_protocol {}; /// Annotation which requests that a special metaclass is created for a type struct metaclass { - handle value; + handle value; - PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") - metaclass() {} + PYBIND11_DEPRECATED( + "py::metaclass() is no longer required. It's turned on by default now.") + metaclass() {} - /// Override pybind11's default metaclass - explicit metaclass(handle value) : value(value) { } + /// Override pybind11's default metaclass + explicit metaclass(handle value) : value(value) {} }; /// Annotation that marks a class as local to the module: -struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; +struct module_local { + const bool value; + constexpr module_local(bool v = true) : value(v) {} +}; /// Annotation to mark enums as an arithmetic type -struct arithmetic { }; +struct arithmetic {}; /** \rst - A call policy which places one or more guard variables (``Ts...``) around the function call. + A call policy which places one or more guard variables (``Ts...``) around + the function call. For example, this definition: @@ -92,20 +115,19 @@ template struct call_guard; template <> struct call_guard<> { using type = detail::void_type; }; -template -struct call_guard { - static_assert(std::is_default_constructible::value, - "The guard type must be default constructible"); +template struct call_guard { + static_assert(std::is_default_constructible::value, + "The guard type must be default constructible"); - using type = T; + using type = T; }; -template -struct call_guard { - struct type { - T guard{}; // Compose multiple guard types with left-to-right default-constructor order - typename call_guard::type next{}; - }; +template struct call_guard { + struct type { + T guard{}; // Compose multiple guard types with left-to-right + // default-constructor order + typename call_guard::type next{}; + }; }; /// @} annotations @@ -115,181 +137,190 @@ NAMESPACE_BEGIN(detail) enum op_id : int; enum op_type : int; struct undefined_t; -template struct op_; -inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); +template +struct op_; +inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, + handle ret); /// Internal data structure which holds metadata about a keyword argument struct argument_record { - const char *name; ///< Argument name - const char *descr; ///< Human-readable version of the argument value - handle value; ///< Associated Python object - bool convert : 1; ///< True if the argument is allowed to convert when loading - bool none : 1; ///< True if None is allowed when loading + const char *name; ///< Argument name + const char *descr; ///< Human-readable version of the argument value + handle value; ///< Associated Python object + bool convert : 1; ///< True if the argument is allowed to convert when loading + bool none : 1; ///< True if None is allowed when loading - argument_record(const char *name, const char *descr, handle value, bool convert, bool none) - : name(name), descr(descr), value(value), convert(convert), none(none) { } + argument_record(const char *name, const char *descr, handle value, + bool convert, bool none) + : name(name), descr(descr), value(value), convert(convert), none(none) {} }; -/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) +/// Internal data structure which holds metadata about a bound function +/// (signature, overloads, etc.) struct function_record { - function_record() - : is_constructor(false), is_new_style_constructor(false), is_stateless(false), - is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } + function_record() + : is_constructor(false), is_new_style_constructor(false), + is_stateless(false), is_operator(false), has_args(false), + has_kwargs(false), is_method(false) {} - /// Function name - char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ + /// Function name + char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ - // User-specified documentation string - char *doc = nullptr; + // User-specified documentation string + char *doc = nullptr; - /// Human-readable version of the function signature - char *signature = nullptr; + /// Human-readable version of the function signature + char *signature = nullptr; - /// List of registered keyword arguments - std::vector args; + /// List of registered keyword arguments + std::vector args; - /// Pointer to lambda function which converts arguments and performs the actual call - handle (*impl) (function_call &) = nullptr; + /// Pointer to lambda function which converts arguments and performs the + /// actual call + handle (*impl)(function_call &) = nullptr; - /// Storage for the wrapped function pointer and captured data, if any - void *data[3] = { }; + /// Storage for the wrapped function pointer and captured data, if any + void *data[3] = {}; - /// Pointer to custom destructor for 'data' (if needed) - void (*free_data) (function_record *ptr) = nullptr; + /// Pointer to custom destructor for 'data' (if needed) + void (*free_data)(function_record *ptr) = nullptr; - /// Return value policy associated with this function - return_value_policy policy = return_value_policy::automatic; + /// Return value policy associated with this function + return_value_policy policy = return_value_policy::automatic; - /// True if name == '__init__' - bool is_constructor : 1; + /// True if name == '__init__' + bool is_constructor : 1; - /// True if this is a new-style `__init__` defined in `detail/init.h` - bool is_new_style_constructor : 1; + /// True if this is a new-style `__init__` defined in `detail/init.h` + bool is_new_style_constructor : 1; - /// True if this is a stateless function pointer - bool is_stateless : 1; + /// True if this is a stateless function pointer + bool is_stateless : 1; - /// True if this is an operator (__add__), etc. - bool is_operator : 1; + /// True if this is an operator (__add__), etc. + bool is_operator : 1; - /// True if the function has a '*args' argument - bool has_args : 1; + /// True if the function has a '*args' argument + bool has_args : 1; - /// True if the function has a '**kwargs' argument - bool has_kwargs : 1; + /// True if the function has a '**kwargs' argument + bool has_kwargs : 1; - /// True if this is a method - bool is_method : 1; + /// True if this is a method + bool is_method : 1; - /// Number of arguments (including py::args and/or py::kwargs, if present) - std::uint16_t nargs; + /// Number of arguments (including py::args and/or py::kwargs, if present) + std::uint16_t nargs; - /// Python method object - PyMethodDef *def = nullptr; + /// Python method object + PyMethodDef *def = nullptr; - /// Python handle to the parent scope (a class or a module) - handle scope; + /// Python handle to the parent scope (a class or a module) + handle scope; - /// Python handle to the sibling function representing an overload chain - handle sibling; + /// Python handle to the sibling function representing an overload chain + handle sibling; - /// Pointer to next overload - function_record *next = nullptr; + /// Pointer to next overload + function_record *next = nullptr; }; -/// Special data structure which (temporarily) holds metadata about a bound class +/// Special data structure which (temporarily) holds metadata about a bound +/// class struct type_record { - PYBIND11_NOINLINE type_record() - : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), - default_holder(true), module_local(false) { } + PYBIND11_NOINLINE type_record() + : multiple_inheritance(false), dynamic_attr(false), + buffer_protocol(false), default_holder(true), module_local(false) {} - /// Handle to the parent scope - handle scope; + /// Handle to the parent scope + handle scope; - /// Name of the class - const char *name = nullptr; + /// Name of the class + const char *name = nullptr; - // Pointer to RTTI type_info data structure - const std::type_info *type = nullptr; + // Pointer to RTTI type_info data structure + const std::type_info *type = nullptr; - /// How large is the underlying C++ type? - size_t type_size = 0; + /// How large is the underlying C++ type? + size_t type_size = 0; - /// What is the alignment of the underlying C++ type? - size_t type_align = 0; + /// What is the alignment of the underlying C++ type? + size_t type_align = 0; - /// How large is the type's holder? - size_t holder_size = 0; + /// How large is the type's holder? + size_t holder_size = 0; - /// The global operator new can be overridden with a class-specific variant - void *(*operator_new)(size_t) = nullptr; + /// The global operator new can be overridden with a class-specific variant + void *(*operator_new)(size_t) = nullptr; - /// Function pointer to class_<..>::init_instance - void (*init_instance)(instance *, const void *) = nullptr; + /// Function pointer to class_<..>::init_instance + void (*init_instance)(instance *, const void *) = nullptr; - /// Function pointer to class_<..>::dealloc - void (*dealloc)(detail::value_and_holder &) = nullptr; + /// Function pointer to class_<..>::dealloc + void (*dealloc)(detail::value_and_holder &) = nullptr; - /// List of base classes of the newly created type - list bases; + /// List of base classes of the newly created type + list bases; - /// Optional docstring - const char *doc = nullptr; + /// Optional docstring + const char *doc = nullptr; - /// Custom metaclass (optional) - handle metaclass; + /// Custom metaclass (optional) + handle metaclass; - /// Multiple inheritance marker - bool multiple_inheritance : 1; + /// Multiple inheritance marker + bool multiple_inheritance : 1; - /// Does the class manage a __dict__? - bool dynamic_attr : 1; + /// Does the class manage a __dict__? + bool dynamic_attr : 1; - /// Does the class implement the buffer protocol? - bool buffer_protocol : 1; + /// Does the class implement the buffer protocol? + bool buffer_protocol : 1; - /// Is the default (unique_ptr) holder type used? - bool default_holder : 1; + /// Is the default (unique_ptr) holder type used? + bool default_holder : 1; - /// Is the class definition local to the module shared object? - bool module_local : 1; + /// Is the class definition local to the module shared object? + bool module_local : 1; - PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { - auto base_info = detail::get_type_info(base, false); - if (!base_info) { - std::string tname(base.name()); - detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + - "\" referenced unknown base type \"" + tname + "\""); - } - - if (default_holder != base_info->default_holder) { - std::string tname(base.name()); - detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + - (default_holder ? "does not have" : "has") + - " a non-default holder type while its base \"" + tname + "\" " + - (base_info->default_holder ? "does not" : "does")); - } - - bases.append((PyObject *) base_info->type); - - if (base_info->type->tp_dictoffset != 0) - dynamic_attr = true; - - if (caster) - base_info->implicit_casts.emplace_back(type, caster); + PYBIND11_NOINLINE void add_base(const std::type_info &base, + void *(*caster)(void *)) { + auto base_info = detail::get_type_info(base, false); + if (!base_info) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + + "\" referenced unknown base type \"" + tname + "\""); } + + if (default_holder != base_info->default_holder) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + + (default_holder ? "does not have" : "has") + + " a non-default holder type while its base \"" + tname + + "\" " + (base_info->default_holder ? "does not" : "does")); + } + + bases.append((PyObject *)base_info->type); + + if (base_info->type->tp_dictoffset != 0) + dynamic_attr = true; + + if (caster) + base_info->implicit_casts.emplace_back(type, caster); + } }; -inline function_call::function_call(const function_record &f, handle p) : - func(f), parent(p) { - args.reserve(f.nargs); - args_convert.reserve(f.nargs); +inline function_call::function_call(const function_record &f, handle p) + : func(f), parent(p) { + args.reserve(f.nargs); + args_convert.reserve(f.nargs); } /// Tag for a new-style `__init__` defined in `detail/init.h` -struct is_new_style_constructor { }; +struct is_new_style_constructor {}; /** * Partial template specializations to process custom attributes provided to @@ -300,135 +331,191 @@ struct is_new_style_constructor { }; template struct process_attribute; template struct process_attribute_default { - /// Default implementation: do nothing - static void init(const T &, function_record *) { } - static void init(const T &, type_record *) { } - static void precall(function_call &) { } - static void postcall(function_call &, handle) { } + /// Default implementation: do nothing + static void init(const T &, function_record *) {} + static void init(const T &, type_record *) {} + static void precall(function_call &) {} + static void postcall(function_call &, handle) {} }; /// Process an attribute specifying the function's name template <> struct process_attribute : process_attribute_default { - static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } + static void init(const name &n, function_record *r) { + r->name = const_cast(n.value); + } }; /// Process an attribute specifying the function's docstring template <> struct process_attribute : process_attribute_default { - static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } + static void init(const doc &n, function_record *r) { + r->doc = const_cast(n.value); + } }; -/// Process an attribute specifying the function's docstring (provided as a C-style string) -template <> struct process_attribute : process_attribute_default { - static void init(const char *d, function_record *r) { r->doc = const_cast(d); } - static void init(const char *d, type_record *r) { r->doc = const_cast(d); } +/// Process an attribute specifying the function's docstring (provided as a +/// C-style string) +template <> +struct process_attribute + : process_attribute_default { + static void init(const char *d, function_record *r) { + r->doc = const_cast(d); + } + static void init(const char *d, type_record *r) { + r->doc = const_cast(d); + } }; -template <> struct process_attribute : process_attribute { }; +template <> +struct process_attribute : process_attribute {}; /// Process an attribute indicating the function's return value policy -template <> struct process_attribute : process_attribute_default { - static void init(const return_value_policy &p, function_record *r) { r->policy = p; } +template <> +struct process_attribute + : process_attribute_default { + static void init(const return_value_policy &p, function_record *r) { + r->policy = p; + } }; -/// Process an attribute which indicates that this is an overloaded function associated with a given sibling -template <> struct process_attribute : process_attribute_default { - static void init(const sibling &s, function_record *r) { r->sibling = s.value; } +/// Process an attribute which indicates that this is an overloaded function +/// associated with a given sibling +template <> +struct process_attribute : process_attribute_default { + static void init(const sibling &s, function_record *r) { + r->sibling = s.value; + } }; /// Process an attribute which indicates that this function is a method -template <> struct process_attribute : process_attribute_default { - static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } +template <> +struct process_attribute : process_attribute_default { + static void init(const is_method &s, function_record *r) { + r->is_method = true; + r->scope = s.class_; + } }; /// Process an attribute which indicates the parent scope of a method template <> struct process_attribute : process_attribute_default { - static void init(const scope &s, function_record *r) { r->scope = s.value; } + static void init(const scope &s, function_record *r) { r->scope = s.value; } }; /// Process an attribute which indicates that this function is an operator -template <> struct process_attribute : process_attribute_default { - static void init(const is_operator &, function_record *r) { r->is_operator = true; } +template <> +struct process_attribute : process_attribute_default { + static void init(const is_operator &, function_record *r) { + r->is_operator = true; + } }; -template <> struct process_attribute : process_attribute_default { - static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } +template <> +struct process_attribute + : process_attribute_default { + static void init(const is_new_style_constructor &, function_record *r) { + r->is_new_style_constructor = true; + } }; /// Process a keyword argument attribute (*without* a default value) template <> struct process_attribute : process_attribute_default { - static void init(const arg &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); - r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); - } + static void init(const arg &a, function_record *r) { + if (r->is_method && r->args.empty()) + r->args.emplace_back("self", nullptr, handle(), true /*convert*/, + false /*none not allowed*/); + r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, + a.flag_none); + } }; /// Process a keyword argument attribute (*with* a default value) template <> struct process_attribute : process_attribute_default { - static void init(const arg_v &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); + static void init(const arg_v &a, function_record *r) { + if (r->is_method && r->args.empty()) + r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, + true /*convert*/, false /*none not allowed*/); - if (!a.value) { + if (!a.value) { #if !defined(NDEBUG) - std::string descr("'"); - if (a.name) descr += std::string(a.name) + ": "; - descr += a.type + "'"; - if (r->is_method) { - if (r->name) - descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; - else - descr += " in method of '" + (std::string) str(r->scope) + "'"; - } else if (r->name) { - descr += " in function '" + (std::string) r->name + "'"; - } - pybind11_fail("arg(): could not convert default argument " - + descr + " into a Python object (type not registered yet?)"); + std::string descr("'"); + if (a.name) + descr += std::string(a.name) + ": "; + descr += a.type + "'"; + if (r->is_method) { + if (r->name) + descr += " in method '" + (std::string)str(r->scope) + "." + + (std::string)r->name + "'"; + else + descr += " in method of '" + (std::string)str(r->scope) + "'"; + } else if (r->name) { + descr += " in function '" + (std::string)r->name + "'"; + } + pybind11_fail("arg(): could not convert default argument " + descr + + " into a Python object (type not registered yet?)"); #else - pybind11_fail("arg(): could not convert default argument " - "into a Python object (type not registered yet?). " - "Compile in debug mode for more information."); + pybind11_fail("arg(): could not convert default argument " + "into a Python object (type not registered yet?). " + "Compile in debug mode for more information."); #endif - } - r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); } + r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, + a.flag_none); + } }; -/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that) +/// Process a parent class attribute. Single inheritance only (class_ itself +/// already guarantees that) template -struct process_attribute::value>> : process_attribute_default { - static void init(const handle &h, type_record *r) { r->bases.append(h); } +struct process_attribute::value>> + : process_attribute_default { + static void init(const handle &h, type_record *r) { r->bases.append(h); } }; -/// Process a parent class attribute (deprecated, does not support multiple inheritance) +/// Process a parent class attribute (deprecated, does not support multiple +/// inheritance) template struct process_attribute> : process_attribute_default> { - static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } + static void init(const base &, type_record *r) { + r->add_base(typeid(T), nullptr); + } }; /// Process a multiple inheritance attribute template <> -struct process_attribute : process_attribute_default { - static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } +struct process_attribute + : process_attribute_default { + static void init(const multiple_inheritance &, type_record *r) { + r->multiple_inheritance = true; + } }; template <> -struct process_attribute : process_attribute_default { - static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } +struct process_attribute + : process_attribute_default { + static void init(const dynamic_attr &, type_record *r) { + r->dynamic_attr = true; + } }; template <> -struct process_attribute : process_attribute_default { - static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } +struct process_attribute + : process_attribute_default { + static void init(const buffer_protocol &, type_record *r) { + r->buffer_protocol = true; + } }; template <> struct process_attribute : process_attribute_default { - static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } + static void init(const metaclass &m, type_record *r) { + r->metaclass = m.value; + } }; template <> -struct process_attribute : process_attribute_default { - static void init(const module_local &l, type_record *r) { r->module_local = l.value; } +struct process_attribute + : process_attribute_default { + static void init(const module_local &l, type_record *r) { + r->module_local = l.value; + } }; /// Process an 'arithmetic' attribute for enums (does nothing here) @@ -436,57 +523,78 @@ template <> struct process_attribute : process_attribute_default {}; template -struct process_attribute> : process_attribute_default> { }; +struct process_attribute> + : process_attribute_default> {}; /** * Process a keep_alive call policy -- invokes keep_alive_impl during the * pre-call handler if both Nurse, Patient != 0 and use the post-call handler * otherwise */ -template struct process_attribute> : public process_attribute_default> { - template = 0> - static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } - template = 0> - static void postcall(function_call &, handle) { } - template = 0> - static void precall(function_call &) { } - template = 0> - static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } +template +struct process_attribute> + : public process_attribute_default> { + template = 0> + static void precall(function_call &call) { + keep_alive_impl(Nurse, Patient, call, handle()); + } + template = 0> + static void postcall(function_call &, handle) {} + template = 0> + static void precall(function_call &) {} + template = 0> + static void postcall(function_call &call, handle ret) { + keep_alive_impl(Nurse, Patient, call, ret); + } }; /// Recursively iterate over variadic template arguments template struct process_attributes { - static void init(const Args&... args, function_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); - } - static void init(const Args&... args, type_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); - } - static void precall(function_call &call) { - int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; - ignore_unused(unused); - } - static void postcall(function_call &call, handle fn_ret) { - int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... }; - ignore_unused(unused); - } + static void init(const Args &... args, function_record *r) { + int unused[] = { + 0, (process_attribute::type>::init(args, r), + 0)...}; + ignore_unused(unused); + } + static void init(const Args &... args, type_record *r) { + int unused[] = { + 0, (process_attribute::type>::init(args, r), + 0)...}; + ignore_unused(unused); + } + static void precall(function_call &call) { + int unused[] = { + 0, (process_attribute::type>::precall(call), + 0)...}; + ignore_unused(unused); + } + static void postcall(function_call &call, handle fn_ret) { + int unused[] = { + 0, (process_attribute::type>::postcall( + call, fn_ret), + 0)...}; + ignore_unused(unused); + } }; -template -using is_call_guard = is_instantiation; +template using is_call_guard = is_instantiation; -/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) +/// Extract the ``type`` from the first `call_guard` in `Extras...` (or +/// `void_type` if none found) template -using extract_guard_t = typename exactly_one_t, Extra...>::type; +using extract_guard_t = + typename exactly_one_t, Extra...>::type; /// Check the number of named arguments at compile time template ::value...), - size_t self = constexpr_sum(std::is_same::value...)> + size_t self = constexpr_sum(std::is_same::value...)> constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { - return named == 0 || (self + named + has_args + has_kwargs) == nargs; + return named == 0 || (self + named + has_args + has_kwargs) == nargs; } NAMESPACE_END(detail) diff --git a/python/src/pybind11/buffer_info.h b/python/src/pybind11/buffer_info.h index 9f072fa73..3c080140c 100644 --- a/python/src/pybind11/buffer_info.h +++ b/python/src/pybind11/buffer_info.h @@ -15,93 +15,112 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE) /// Information record describing a Python buffer object struct buffer_info { - void *ptr = nullptr; // Pointer to the underlying storage - ssize_t itemsize = 0; // Size of individual items in bytes - ssize_t size = 0; // Total number of entries - std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() - ssize_t ndim = 0; // Number of dimensions - std::vector shape; // Shape of the tensor (1 entry per dimension) - std::vector strides; // Number of entries between adjacent entries (for each per dimension) + void *ptr = nullptr; // Pointer to the underlying storage + ssize_t itemsize = 0; // Size of individual items in bytes + ssize_t size = 0; // Total number of entries + std::string format; // For homogeneous buffers, this should be set to + // format_descriptor::format() + ssize_t ndim = 0; // Number of dimensions + std::vector shape; // Shape of the tensor (1 entry per dimension) + std::vector strides; // Number of entries between adjacent entries + // (for each per dimension) - buffer_info() { } + buffer_info() {} - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container shape_in, detail::any_container strides_in) - : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), - shape(std::move(shape_in)), strides(std::move(strides_in)) { - if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) - pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); - for (size_t i = 0; i < (size_t) ndim; ++i) - size *= shape[i]; - } - - template - buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) - : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } - - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) - : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } - - template - buffer_info(T *ptr, ssize_t size) - : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } - - explicit buffer_info(Py_buffer *view, bool ownview = true) - : buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { - this->view = view; - this->ownview = ownview; - } - - buffer_info(const buffer_info &) = delete; - buffer_info& operator=(const buffer_info &) = delete; - - buffer_info(buffer_info &&other) { - (*this) = std::move(other); - } - - buffer_info& operator=(buffer_info &&rhs) { - ptr = rhs.ptr; - itemsize = rhs.itemsize; - size = rhs.size; - format = std::move(rhs.format); - ndim = rhs.ndim; - shape = std::move(rhs.shape); - strides = std::move(rhs.strides); - std::swap(view, rhs.view); - std::swap(ownview, rhs.ownview); - return *this; - } - - ~buffer_info() { - if (view && ownview) { PyBuffer_Release(view); delete view; } + buffer_info(void *ptr, ssize_t itemsize, const std::string &format, + ssize_t ndim, detail::any_container shape_in, + detail::any_container strides_in) + : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)) { + if (ndim != (ssize_t)shape.size() || ndim != (ssize_t)strides.size()) + pybind11_fail( + "buffer_info: ndim doesn't match shape and/or strides length"); + for (size_t i = 0; i < (size_t)ndim; ++i) + size *= shape[i]; + } + + template + buffer_info(T *ptr, detail::any_container shape_in, + detail::any_container strides_in) + : buffer_info(private_ctr_tag(), ptr, sizeof(T), + format_descriptor::format(), + static_cast(shape_in->size()), std::move(shape_in), + std::move(strides_in)) {} + + buffer_info(void *ptr, ssize_t itemsize, const std::string &format, + ssize_t size) + : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) {} + + template + buffer_info(T *ptr, ssize_t size) + : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) {} + + explicit buffer_info(Py_buffer *view, bool ownview = true) + : buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + {view->strides, view->strides + view->ndim}) { + this->view = view; + this->ownview = ownview; + } + + buffer_info(const buffer_info &) = delete; + buffer_info &operator=(const buffer_info &) = delete; + + buffer_info(buffer_info &&other) { (*this) = std::move(other); } + + buffer_info &operator=(buffer_info &&rhs) { + ptr = rhs.ptr; + itemsize = rhs.itemsize; + size = rhs.size; + format = std::move(rhs.format); + ndim = rhs.ndim; + shape = std::move(rhs.shape); + strides = std::move(rhs.strides); + std::swap(view, rhs.view); + std::swap(ownview, rhs.ownview); + return *this; + } + + ~buffer_info() { + if (view && ownview) { + PyBuffer_Release(view); + delete view; } + } private: - struct private_ctr_tag { }; + struct private_ctr_tag {}; - buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container &&shape_in, detail::any_container &&strides_in) - : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } + buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, + const std::string &format, ssize_t ndim, + detail::any_container &&shape_in, + detail::any_container &&strides_in) + : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), + std::move(strides_in)) {} - Py_buffer *view = nullptr; - bool ownview = false; + Py_buffer *view = nullptr; + bool ownview = false; }; NAMESPACE_BEGIN(detail) template struct compare_buffer_info { - static bool compare(const buffer_info& b) { - return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); - } + static bool compare(const buffer_info &b) { + return b.format == format_descriptor::format() && + b.itemsize == (ssize_t)sizeof(T); + } }; -template struct compare_buffer_info::value>> { - static bool compare(const buffer_info& b) { - return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || - ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || - ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); - } +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info &b) { + return (size_t)b.itemsize == sizeof(T) && + (b.format == format_descriptor::value || + ((sizeof(T) == sizeof(long)) && + b.format == (std::is_unsigned::value ? "L" : "l")) || + ((sizeof(T) == sizeof(size_t)) && + b.format == (std::is_unsigned::value ? "N" : "n"))); + } }; NAMESPACE_END(detail) diff --git a/python/src/pybind11/cast.h b/python/src/pybind11/cast.h index 8d0fd5d90..49a9adb29 100644 --- a/python/src/pybind11/cast.h +++ b/python/src/pybind11/cast.h @@ -10,23 +10,23 @@ #pragma once -#include "pytypes.h" -#include "detail/typeid.h" #include "detail/descr.h" #include "detail/internals.h" +#include "detail/typeid.h" +#include "pytypes.h" #include #include #include #include #if defined(PYBIND11_CPP17) -# if defined(__has_include) -# if __has_include() -# define PYBIND11_HAS_STRING_VIEW -# endif -# elif defined(_MSC_VER) -# define PYBIND11_HAS_STRING_VIEW -# endif +#if defined(__has_include) +#if __has_include() +#define PYBIND11_HAS_STRING_VIEW +#endif +#elif defined(_MSC_VER) +#define PYBIND11_HAS_STRING_VIEW +#endif #endif #ifdef PYBIND11_HAS_STRING_VIEW #include @@ -35,444 +35,486 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE) NAMESPACE_BEGIN(detail) -/// A life support system for temporary objects created by `type_caster::load()`. -/// Adding a patient will keep it alive up until the enclosing function returns. +/// A life support system for temporary objects created by +/// `type_caster::load()`. Adding a patient will keep it alive up until the +/// enclosing function returns. class loader_life_support { public: - /// A new patient frame is created when a function is entered - loader_life_support() { - get_internals().loader_patient_stack.push_back(nullptr); - } - - /// ... and destroyed after it returns - ~loader_life_support() { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - pybind11_fail("loader_life_support: internal error"); - - auto ptr = stack.back(); - stack.pop_back(); - Py_CLEAR(ptr); - - // A heuristic to reduce the stack's capacity (e.g. after long recursive calls) - if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2) - stack.shrink_to_fit(); - } - - /// This can only be used inside a pybind11-bound function, either by `argument_loader` - /// at argument preparation time or by `py::cast()` at execution time. - PYBIND11_NOINLINE static void add_patient(handle h) { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - throw cast_error("When called outside a bound function, py::cast() cannot " - "do Python -> C++ conversions which require the creation " - "of temporary values"); - - auto &list_ptr = stack.back(); - if (list_ptr == nullptr) { - list_ptr = PyList_New(1); - if (!list_ptr) - pybind11_fail("loader_life_support: error allocating list"); - PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); - } else { - auto result = PyList_Append(list_ptr, h.ptr()); - if (result == -1) - pybind11_fail("loader_life_support: error adding patient"); - } + /// A new patient frame is created when a function is entered + loader_life_support() { + get_internals().loader_patient_stack.push_back(nullptr); + } + + /// ... and destroyed after it returns + ~loader_life_support() { + auto &stack = get_internals().loader_patient_stack; + if (stack.empty()) + pybind11_fail("loader_life_support: internal error"); + + auto ptr = stack.back(); + stack.pop_back(); + Py_CLEAR(ptr); + + // A heuristic to reduce the stack's capacity (e.g. after long recursive + // calls) + if (stack.capacity() > 16 && stack.size() != 0 && + stack.capacity() / stack.size() > 2) + stack.shrink_to_fit(); + } + + /// This can only be used inside a pybind11-bound function, either by + /// `argument_loader` at argument preparation time or by `py::cast()` at + /// execution time. + PYBIND11_NOINLINE static void add_patient(handle h) { + auto &stack = get_internals().loader_patient_stack; + if (stack.empty()) + throw cast_error( + "When called outside a bound function, py::cast() cannot " + "do Python -> C++ conversions which require the creation " + "of temporary values"); + + auto &list_ptr = stack.back(); + if (list_ptr == nullptr) { + list_ptr = PyList_New(1); + if (!list_ptr) + pybind11_fail("loader_life_support: error allocating list"); + PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); + } else { + auto result = PyList_Append(list_ptr, h.ptr()); + if (result == -1) + pybind11_fail("loader_life_support: error adding patient"); } + } }; -// Gets the cache entry for the given type, creating it if necessary. The return value is the pair -// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was -// just created. -inline std::pair all_type_info_get_cache(PyTypeObject *type); +// Gets the cache entry for the given type, creating it if necessary. The +// return value is the pair returned by emplace, i.e. an iterator for the entry +// and a bool set to `true` if the entry was just created. +inline std::pair +all_type_info_get_cache(PyTypeObject *type); // Populates a just-created cache entry. -PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { - std::vector check; - for (handle parent : reinterpret_borrow(t->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); +PYBIND11_NOINLINE inline void +all_type_info_populate(PyTypeObject *t, std::vector &bases) { + std::vector check; + for (handle parent : reinterpret_borrow(t->tp_bases)) + check.push_back((PyTypeObject *)parent.ptr()); - auto const &type_dict = get_internals().registered_types_py; - for (size_t i = 0; i < check.size(); i++) { - auto type = check[i]; - // Ignore Python2 old-style class super type: - if (!PyType_Check((PyObject *) type)) continue; + auto const &type_dict = get_internals().registered_types_py; + for (size_t i = 0; i < check.size(); i++) { + auto type = check[i]; + // Ignore Python2 old-style class super type: + if (!PyType_Check((PyObject *)type)) + continue; - // Check `type` in the current set of registered python types: - auto it = type_dict.find(type); - if (it != type_dict.end()) { - // We found a cache entry for it, so it's either pybind-registered or has pre-computed - // pybind bases, but we have to make sure we haven't already seen the type(s) before: we - // want to follow Python/virtual C++ rules that there should only be one instance of a - // common base. - for (auto *tinfo : it->second) { - // NB: Could use a second set here, rather than doing a linear search, but since - // having a large number of immediate pybind11-registered types seems fairly - // unlikely, that probably isn't worthwhile. - bool found = false; - for (auto *known : bases) { - if (known == tinfo) { found = true; break; } - } - if (!found) bases.push_back(tinfo); - } - } - else if (type->tp_bases) { - // It's some python type, so keep follow its bases classes to look for one or more - // registered types - if (i + 1 == check.size()) { - // When we're at the end, we can pop off the current element to avoid growing - // `check` when adding just one base (which is typical--i.e. when there is no - // multiple inheritance) - check.pop_back(); - i--; - } - for (handle parent : reinterpret_borrow(type->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); + // Check `type` in the current set of registered python types: + auto it = type_dict.find(type); + if (it != type_dict.end()) { + // We found a cache entry for it, so it's either pybind-registered or has + // pre-computed pybind bases, but we have to make sure we haven't already + // seen the type(s) before: we want to follow Python/virtual C++ rules + // that there should only be one instance of a common base. + for (auto *tinfo : it->second) { + // NB: Could use a second set here, rather than doing a linear search, + // but since having a large number of immediate pybind11-registered + // types seems fairly unlikely, that probably isn't worthwhile. + bool found = false; + for (auto *known : bases) { + if (known == tinfo) { + found = true; + break; + } } + if (!found) + bases.push_back(tinfo); + } + } else if (type->tp_bases) { + // It's some python type, so keep follow its bases classes to look for one + // or more registered types + if (i + 1 == check.size()) { + // When we're at the end, we can pop off the current element to avoid + // growing `check` when adding just one base (which is typical--i.e. + // when there is no multiple inheritance) + check.pop_back(); + i--; + } + for (handle parent : reinterpret_borrow(type->tp_bases)) + check.push_back((PyTypeObject *)parent.ptr()); } + } } /** - * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will - * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side - * derived class that uses single inheritance. Will contain as many types as required for a Python - * class that uses multiple inheritance to inherit (directly or indirectly) from multiple - * pybind-registered classes. Will be empty if neither the type nor any base classes are - * pybind-registered. + * Extracts vector of type_info pointers of pybind-registered roots of the given + * Python type. Will be just 1 pybind type for the Python type of a + * pybind-registered class, or for any Python-side derived class that uses + * single inheritance. Will contain as many types as required for a Python + * class that uses multiple inheritance to inherit (directly or indirectly) from + * multiple pybind-registered classes. Will be empty if neither the type nor + * any base classes are pybind-registered. * * The value is cached for the lifetime of the Python type. */ -inline const std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); +inline const std::vector & +all_type_info(PyTypeObject *type) { + auto ins = all_type_info_get_cache(type); + if (ins.second) + // New cache entry: populate it + all_type_info_populate(type, ins.first->second); - return ins.first->second; + return ins.first->second; } /** - * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any - * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use - * `all_type_info` instead if you want to support multiple bases. + * Gets a single pybind11 type info for a python type. Returns nullptr if + * neither the type nor any ancestors are pybind11-registered. Throws an + * exception if there are multiple bases--use `all_type_info` instead if you + * want to support multiple bases. */ -PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { - auto &bases = all_type_info(type); - if (bases.size() == 0) - return nullptr; - if (bases.size() > 1) - pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); - return bases.front(); +PYBIND11_NOINLINE inline detail::type_info *get_type_info(PyTypeObject *type) { + auto &bases = all_type_info(type); + if (bases.size() == 0) + return nullptr; + if (bases.size() > 1) + pybind11_fail("pybind11::detail::get_type_info: type has multiple " + "pybind11-registered bases"); + return bases.front(); } inline detail::type_info *get_local_type_info(const std::type_index &tp) { - auto &locals = registered_local_types_cpp(); - auto it = locals.find(tp); - if (it != locals.end()) - return it->second; - return nullptr; + auto &locals = registered_local_types_cpp(); + auto it = locals.find(tp); + if (it != locals.end()) + return it->second; + return nullptr; } inline detail::type_info *get_global_type_info(const std::type_index &tp) { - auto &types = get_internals().registered_types_cpp; - auto it = types.find(tp); - if (it != types.end()) - return it->second; - return nullptr; + auto &types = get_internals().registered_types_cpp; + auto it = types.find(tp); + if (it != types.end()) + return it->second; + return nullptr; } -/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. -PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, - bool throw_if_missing = false) { - if (auto ltype = get_local_type_info(tp)) - return ltype; - if (auto gtype = get_global_type_info(tp)) - return gtype; +/// Return the type info for a given C++ type; on lookup failure can either +/// throw or return nullptr. +PYBIND11_NOINLINE inline detail::type_info * +get_type_info(const std::type_index &tp, bool throw_if_missing = false) { + if (auto ltype = get_local_type_info(tp)) + return ltype; + if (auto gtype = get_global_type_info(tp)) + return gtype; - if (throw_if_missing) { - std::string tname = tp.name(); - detail::clean_type_id(tname); - pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); - } - return nullptr; + if (throw_if_missing) { + std::string tname = tp.name(); + detail::clean_type_id(tname); + pybind11_fail( + "pybind11::detail::get_type_info: unable to find type info for \"" + + tname + "\""); + } + return nullptr; } -PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { - detail::type_info *type_info = get_type_info(tp, throw_if_missing); - return handle(type_info ? ((PyObject *) type_info->type) : nullptr); +PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, + bool throw_if_missing) { + detail::type_info *type_info = get_type_info(tp, throw_if_missing); + return handle(type_info ? ((PyObject *)type_info->type) : nullptr); } struct value_and_holder { - instance *inst = nullptr; - size_t index = 0u; - const detail::type_info *type = nullptr; - void **vh = nullptr; + instance *inst = nullptr; + size_t index = 0u; + const detail::type_info *type = nullptr; + void **vh = nullptr; - // Main constructor for a found value/holder: - value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : - inst{i}, index{index}, type{type}, - vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} - {} + // Main constructor for a found value/holder: + value_and_holder(instance *i, const detail::type_info *type, size_t vpos, + size_t index) + : inst{i}, index{index}, type{type}, + vh{inst->simple_layout ? inst->simple_value_holder + : &inst->nonsimple.values_and_holders[vpos]} {} - // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) - value_and_holder() {} + // Default constructor (used to signal a value-and-holder not found by + // get_value_and_holder()) + value_and_holder() {} - // Used for past-the-end iterator - value_and_holder(size_t index) : index{index} {} + // Used for past-the-end iterator + value_and_holder(size_t index) : index{index} {} - template V *&value_ptr() const { - return reinterpret_cast(vh[0]); - } - // True if this `value_and_holder` has a non-null value pointer - explicit operator bool() const { return value_ptr(); } + template V *&value_ptr() const { + return reinterpret_cast(vh[0]); + } + // True if this `value_and_holder` has a non-null value pointer + explicit operator bool() const { return value_ptr(); } - template H &holder() const { - return reinterpret_cast(vh[1]); - } - bool holder_constructed() const { - return inst->simple_layout - ? inst->simple_holder_constructed - : inst->nonsimple.status[index] & instance::status_holder_constructed; - } - void set_holder_constructed(bool v = true) { - if (inst->simple_layout) - inst->simple_holder_constructed = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_holder_constructed; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; - } - bool instance_registered() const { - return inst->simple_layout - ? inst->simple_instance_registered - : inst->nonsimple.status[index] & instance::status_instance_registered; - } - void set_instance_registered(bool v = true) { - if (inst->simple_layout) - inst->simple_instance_registered = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_instance_registered; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; - } + template H &holder() const { + return reinterpret_cast(vh[1]); + } + bool holder_constructed() const { + return inst->simple_layout ? inst->simple_holder_constructed + : inst->nonsimple.status[index] & + instance::status_holder_constructed; + } + void set_holder_constructed(bool v = true) { + if (inst->simple_layout) + inst->simple_holder_constructed = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_holder_constructed; + else + inst->nonsimple.status[index] &= + (uint8_t)~instance::status_holder_constructed; + } + bool instance_registered() const { + return inst->simple_layout ? inst->simple_instance_registered + : inst->nonsimple.status[index] & + instance::status_instance_registered; + } + void set_instance_registered(bool v = true) { + if (inst->simple_layout) + inst->simple_instance_registered = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_instance_registered; + else + inst->nonsimple.status[index] &= + (uint8_t)~instance::status_instance_registered; + } }; // Container for accessing and iterating over an instance's values/holders struct values_and_holders { private: - instance *inst; - using type_vec = std::vector; - const type_vec &tinfo; + instance *inst; + using type_vec = std::vector; + const type_vec &tinfo; public: - values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} + values_and_holders(instance *inst) + : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} - struct iterator { - private: - instance *inst = nullptr; - const type_vec *types = nullptr; - value_and_holder curr; - friend struct values_and_holders; - iterator(instance *inst, const type_vec *tinfo) - : inst{inst}, types{tinfo}, - curr(inst /* instance */, - types->empty() ? nullptr : (*types)[0] /* type info */, - 0, /* vpos: (non-simple types only): the first vptr comes first */ - 0 /* index */) - {} - // Past-the-end iterator: - iterator(size_t end) : curr(end) {} - public: - bool operator==(const iterator &other) { return curr.index == other.curr.index; } - bool operator!=(const iterator &other) { return curr.index != other.curr.index; } - iterator &operator++() { - if (!inst->simple_layout) - curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; - ++curr.index; - curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; - return *this; - } - value_and_holder &operator*() { return curr; } - value_and_holder *operator->() { return &curr; } - }; + struct iterator { + private: + instance *inst = nullptr; + const type_vec *types = nullptr; + value_and_holder curr; + friend struct values_and_holders; + iterator(instance *inst, const type_vec *tinfo) + : inst{inst}, types{tinfo}, + curr( + inst /* instance */, + types->empty() ? nullptr : (*types)[0] /* type info */, + 0, /* vpos: (non-simple types only): the first vptr comes first */ + 0 /* index */) {} + // Past-the-end iterator: + iterator(size_t end) : curr(end) {} - iterator begin() { return iterator(inst, &tinfo); } - iterator end() { return iterator(tinfo.size()); } - - iterator find(const type_info *find_type) { - auto it = begin(), endit = end(); - while (it != endit && it->type != find_type) ++it; - return it; + public: + bool operator==(const iterator &other) { + return curr.index == other.curr.index; } + bool operator!=(const iterator &other) { + return curr.index != other.curr.index; + } + iterator &operator++() { + if (!inst->simple_layout) + curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; + ++curr.index; + curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; + return *this; + } + value_and_holder &operator*() { return curr; } + value_and_holder *operator->() { return &curr; } + }; - size_t size() { return tinfo.size(); } + iterator begin() { return iterator(inst, &tinfo); } + iterator end() { return iterator(tinfo.size()); } + + iterator find(const type_info *find_type) { + auto it = begin(), endit = end(); + while (it != endit && it->type != find_type) + ++it; + return it; + } + + size_t size() { return tinfo.size(); } }; /** - * Extracts C++ value and holder pointer references from an instance (which may contain multiple - * values/holders for python-side multiple inheritance) that match the given type. Throws an error - * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If - * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, - * regardless of type (and the resulting .type will be nullptr). + * Extracts C++ value and holder pointer references from an instance (which may + * contain multiple values/holders for python-side multiple inheritance) that + * match the given type. Throws an error if the given type (or ValueType, if + * omitted) is not a pybind11 base of the given instance. If `find_type` is + * omitted (or explicitly specified as nullptr) the first value/holder are + * returned, regardless of type (and the resulting .type will be nullptr). * - * The returned object should be short-lived: in particular, it must not outlive the called-upon - * instance. + * The returned object should be short-lived: in particular, it must not outlive + * the called-upon instance. */ -PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { - // Optimize common case: - if (!find_type || Py_TYPE(this) == find_type->type) - return value_and_holder(this, find_type, 0, 0); +PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder( + const type_info *find_type /*= nullptr default in common.h*/, + bool throw_if_missing /*= true in common.h*/) { + // Optimize common case: + if (!find_type || Py_TYPE(this) == find_type->type) + return value_and_holder(this, find_type, 0, 0); - detail::values_and_holders vhs(this); - auto it = vhs.find(find_type); - if (it != vhs.end()) - return *it; + detail::values_and_holders vhs(this); + auto it = vhs.find(find_type); + if (it != vhs.end()) + return *it; - if (!throw_if_missing) - return value_and_holder(); + if (!throw_if_missing) + return value_and_holder(); #if defined(NDEBUG) - pybind11_fail("pybind11::detail::instance::get_value_and_holder: " - "type is not a pybind11 base of the given instance " - "(compile in debug mode for type details)"); + pybind11_fail("pybind11::detail::instance::get_value_and_holder: " + "type is not a pybind11 base of the given instance " + "(compile in debug mode for type details)"); #else - pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + - std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + - std::string(Py_TYPE(this)->tp_name) + "' instance"); + pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + + std::string(find_type->type->tp_name) + + "' is not a pybind11 base of the given `" + + std::string(Py_TYPE(this)->tp_name) + "' instance"); #endif } PYBIND11_NOINLINE inline void instance::allocate_layout() { - auto &tinfo = all_type_info(Py_TYPE(this)); + auto &tinfo = all_type_info(Py_TYPE(this)); - const size_t n_types = tinfo.size(); + const size_t n_types = tinfo.size(); - if (n_types == 0) - pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); + if (n_types == 0) + pybind11_fail("instance allocation failed: new instance has no " + "pybind11-registered base types"); - simple_layout = - n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); + simple_layout = n_types == 1 && tinfo.front()->holder_size_in_ptrs <= + instance_simple_holder_in_ptrs(); - // Simple path: no python-side multiple inheritance, and a small-enough holder - if (simple_layout) { - simple_value_holder[0] = nullptr; - simple_holder_constructed = false; - simple_instance_registered = false; + // Simple path: no python-side multiple inheritance, and a small-enough holder + if (simple_layout) { + simple_value_holder[0] = nullptr; + simple_holder_constructed = false; + simple_instance_registered = false; + } else { // multiple base types or a too-large holder + // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a + // value pointer, [hN] is the (uninitialized) holder instance for value N, + // and [bb...] is a set of bool values that tracks whether each associated + // holder has been initialized. Each [block] is padded, if necessary, to an + // integer multiple of sizeof(void *). + size_t space = 0; + for (auto t : tinfo) { + space += 1; // value pointer + space += t->holder_size_in_ptrs; // holder instance } - else { // multiple base types or a too-large holder - // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, - // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool - // values that tracks whether each associated holder has been initialized. Each [block] is - // padded, if necessary, to an integer multiple of sizeof(void *). - size_t space = 0; - for (auto t : tinfo) { - space += 1; // value pointer - space += t->holder_size_in_ptrs; // holder instance - } - size_t flags_at = space; - space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) + size_t flags_at = space; + space += size_in_ptrs( + n_types); // status bytes (holder_constructed and instance_registered) - // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, - // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 - // they default to using pymalloc, which is designed to be efficient for small allocations - // like the one we're doing here; in earlier versions (and for larger allocations) they are - // just wrappers around malloc. + // Allocate space for flags, values, and holders, and initialize it to 0 + // (flags and values, in particular, need to be 0). Use Python's memory + // allocation functions: in Python 3.6 they default to using pymalloc, which + // is designed to be efficient for small allocations like the one we're + // doing here; in earlier versions (and for larger allocations) they are + // just wrappers around malloc. #if PY_VERSION_HEX >= 0x03050000 - nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); + nonsimple.values_and_holders = (void **)PyMem_Calloc(space, sizeof(void *)); + if (!nonsimple.values_and_holders) + throw std::bad_alloc(); #else - nonsimple.values_and_holders = (void **) PyMem_New(void *, space); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); - std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); + nonsimple.values_and_holders = (void **)PyMem_New(void *, space); + if (!nonsimple.values_and_holders) + throw std::bad_alloc(); + std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); #endif - nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); - } - owned = true; + nonsimple.status = + reinterpret_cast(&nonsimple.values_and_holders[flags_at]); + } + owned = true; } PYBIND11_NOINLINE inline void instance::deallocate_layout() { - if (!simple_layout) - PyMem_Free(nonsimple.values_and_holders); + if (!simple_layout) + PyMem_Free(nonsimple.values_and_holders); } -PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { - handle type = detail::get_type_handle(tp, false); - if (!type) - return false; - return isinstance(obj, type); +PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, + const std::type_info &tp) { + handle type = detail::get_type_handle(tp, false); + if (!type) + return false; + return isinstance(obj, type); } PYBIND11_NOINLINE inline std::string error_string() { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); - return "Unknown internal error occurred"; - } + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); + return "Unknown internal error occurred"; + } - error_scope scope; // Preserve error state + error_scope scope; // Preserve error state - std::string errorString; - if (scope.type) { - errorString += handle(scope.type).attr("__name__").cast(); - errorString += ": "; - } - if (scope.value) - errorString += (std::string) str(scope.value); + std::string errorString; + if (scope.type) { + errorString += handle(scope.type).attr("__name__").cast(); + errorString += ": "; + } + if (scope.value) + errorString += (std::string)str(scope.value); - PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); + PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); #if PY_MAJOR_VERSION >= 3 - if (scope.trace != nullptr) - PyException_SetTraceback(scope.value, scope.trace); + if (scope.trace != nullptr) + PyException_SetTraceback(scope.value, scope.trace); #endif #if !defined(PYPY_VERSION) - if (scope.trace) { - PyTracebackObject *trace = (PyTracebackObject *) scope.trace; + if (scope.trace) { + PyTracebackObject *trace = (PyTracebackObject *)scope.trace; - /* Get the deepest trace possible */ - while (trace->tb_next) - trace = trace->tb_next; + /* Get the deepest trace possible */ + while (trace->tb_next) + trace = trace->tb_next; - PyFrameObject *frame = trace->tb_frame; - errorString += "\n\nAt:\n"; - while (frame) { - int lineno = PyFrame_GetLineNumber(frame); - errorString += - " " + handle(frame->f_code->co_filename).cast() + - "(" + std::to_string(lineno) + "): " + - handle(frame->f_code->co_name).cast() + "\n"; - frame = frame->f_back; - } + PyFrameObject *frame = trace->tb_frame; + errorString += "\n\nAt:\n"; + while (frame) { + int lineno = PyFrame_GetLineNumber(frame); + errorString += + " " + handle(frame->f_code->co_filename).cast() + "(" + + std::to_string(lineno) + + "): " + handle(frame->f_code->co_name).cast() + "\n"; + frame = frame->f_back; } + } #endif - return errorString; + return errorString; } -PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { - auto &instances = get_internals().registered_instances; - auto range = instances.equal_range(ptr); - for (auto it = range.first; it != range.second; ++it) { - for (auto vh : values_and_holders(it->second)) { - if (vh.type == type) - return handle((PyObject *) it->second); - } +PYBIND11_NOINLINE inline handle +get_object_handle(const void *ptr, const detail::type_info *type) { + auto &instances = get_internals().registered_instances; + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + for (auto vh : values_and_holders(it->second)) { + if (vh.type == type) + return handle((PyObject *)it->second); } - return handle(); + } + return handle(); } inline PyThreadState *get_thread_state_unchecked() { #if defined(PYPY_VERSION) - return PyThreadState_GET(); + return PyThreadState_GET(); #elif PY_VERSION_HEX < 0x03000000 - return _PyThreadState_Current; + return _PyThreadState_Current; #elif PY_VERSION_HEX < 0x03050000 - return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); + return (PyThreadState *)_Py_atomic_load_relaxed(&_PyThreadState_Current); #elif PY_VERSION_HEX < 0x03050200 - return (PyThreadState*) _PyThreadState_Current.value; + return (PyThreadState *)_PyThreadState_Current.value; #else - return _PyThreadState_UncheckedGet(); + return _PyThreadState_UncheckedGet(); #endif } @@ -482,1100 +524,1283 @@ inline PyObject *make_new_instance(PyTypeObject *type); class type_caster_generic { public: - PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) - : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } + PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) + : typeinfo(get_type_info(type_info)), cpptype(&type_info) {} - type_caster_generic(const type_info *typeinfo) - : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { } + type_caster_generic(const type_info *typeinfo) + : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) {} - bool load(handle src, bool convert) { - return load_impl(src, convert); + bool load(handle src, bool convert) { + return load_impl(src, convert); + } + + PYBIND11_NOINLINE static handle + cast(const void *_src, return_value_policy policy, handle parent, + const detail::type_info *tinfo, void *(*copy_constructor)(const void *), + void *(*move_constructor)(const void *), + const void *existing_holder = nullptr) { + if (!tinfo) // no type info: error will be set already + return handle(); + + void *src = const_cast(_src); + if (src == nullptr) + return none().release(); + + auto it_instances = get_internals().registered_instances.equal_range(src); + for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { + for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { + if (instance_type && + same_type(*instance_type->cpptype, *tinfo->cpptype)) + return handle((PyObject *)it_i->second).inc_ref(); + } } - PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, - const detail::type_info *tinfo, - void *(*copy_constructor)(const void *), - void *(*move_constructor)(const void *), - const void *existing_holder = nullptr) { - if (!tinfo) // no type info: error will be set already - return handle(); + auto inst = reinterpret_steal(make_new_instance(tinfo->type)); + auto wrapper = reinterpret_cast(inst.ptr()); + wrapper->owned = false; + void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); - void *src = const_cast(_src); - if (src == nullptr) - return none().release(); + switch (policy) { + case return_value_policy::automatic: + case return_value_policy::take_ownership: + valueptr = src; + wrapper->owned = true; + break; - auto it_instances = get_internals().registered_instances.equal_range(src); - for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { - for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { - if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) - return handle((PyObject *) it_i->second).inc_ref(); - } - } + case return_value_policy::automatic_reference: + case return_value_policy::reference: + valueptr = src; + wrapper->owned = false; + break; - auto inst = reinterpret_steal(make_new_instance(tinfo->type)); - auto wrapper = reinterpret_cast(inst.ptr()); - wrapper->owned = false; - void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); + case return_value_policy::copy: + if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = copy, but the " + "object is non-copyable!"); + wrapper->owned = true; + break; - switch (policy) { - case return_value_policy::automatic: - case return_value_policy::take_ownership: - valueptr = src; - wrapper->owned = true; - break; + case return_value_policy::move: + if (move_constructor) + valueptr = move_constructor(src); + else if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = move, but the " + "object is neither movable nor copyable!"); + wrapper->owned = true; + break; - case return_value_policy::automatic_reference: - case return_value_policy::reference: - valueptr = src; - wrapper->owned = false; - break; + case return_value_policy::reference_internal: + valueptr = src; + wrapper->owned = false; + keep_alive_impl(inst, parent); + break; - case return_value_policy::copy: - if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = copy, but the " - "object is non-copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::move: - if (move_constructor) - valueptr = move_constructor(src); - else if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = move, but the " - "object is neither movable nor copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::reference_internal: - valueptr = src; - wrapper->owned = false; - keep_alive_impl(inst, parent); - break; - - default: - throw cast_error("unhandled return_value_policy: should not happen!"); - } - - tinfo->init_instance(wrapper, existing_holder); - - return inst.release(); + default: + throw cast_error("unhandled return_value_policy: should not happen!"); } - // Base methods for generic caster; there are overridden in copyable_holder_caster - void load_value(value_and_holder &&v_h) { - auto *&vptr = v_h.value_ptr(); - // Lazy allocation for unallocated values: - if (vptr == nullptr) { - auto *type = v_h.type ? v_h.type : typeinfo; - if (type->operator_new) { - vptr = type->operator_new(type->type_size); - } else { - #if defined(PYBIND11_CPP17) - if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) - vptr = ::operator new(type->type_size, - (std::align_val_t) type->type_align); - else - #endif - vptr = ::operator new(type->type_size); - } - } - value = vptr; + tinfo->init_instance(wrapper, existing_holder); + + return inst.release(); + } + + // Base methods for generic caster; there are overridden in + // copyable_holder_caster + void load_value(value_and_holder &&v_h) { + auto *&vptr = v_h.value_ptr(); + // Lazy allocation for unallocated values: + if (vptr == nullptr) { + auto *type = v_h.type ? v_h.type : typeinfo; + if (type->operator_new) { + vptr = type->operator_new(type->type_size); + } else { +#if defined(PYBIND11_CPP17) + if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) + vptr = ::operator new(type->type_size, + (std::align_val_t)type->type_align); + else +#endif + vptr = ::operator new(type->type_size); + } } - bool try_implicit_casts(handle src, bool convert) { - for (auto &cast : typeinfo->implicit_casts) { - type_caster_generic sub_caster(*cast.first); - if (sub_caster.load(src, convert)) { - value = cast.second(sub_caster.value); - return true; - } - } + value = vptr; + } + bool try_implicit_casts(handle src, bool convert) { + for (auto &cast : typeinfo->implicit_casts) { + type_caster_generic sub_caster(*cast.first); + if (sub_caster.load(src, convert)) { + value = cast.second(sub_caster.value); + return true; + } + } + return false; + } + bool try_direct_conversions(handle src) { + for (auto &converter : *typeinfo->direct_conversions) { + if (converter(src.ptr(), value)) + return true; + } + return false; + } + void check_holder_compat() {} + + PYBIND11_NOINLINE static void *local_load(PyObject *src, + const type_info *ti) { + auto caster = type_caster_generic(ti); + if (caster.load(src, false)) + return caster.value; + return nullptr; + } + + /// Try to load with foreign typeinfo, if available. Used when there is no + /// native typeinfo, or when the native one wasn't able to produce a value. + PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { + constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; + const auto pytype = src.get_type(); + if (!hasattr(pytype, local_key)) + return false; + + type_info *foreign_typeinfo = + reinterpret_borrow(getattr(pytype, local_key)); + // Only consider this foreign loader if actually foreign and is a loader of + // the correct cpp type + if (foreign_typeinfo->module_local_load == &local_load || + (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) + return false; + + if (auto result = + foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { + value = result; + return true; + } + return false; + } + + // Implementation of `load`; this takes the type of `this` so that it can + // dispatch the relevant bits of code between here and copyable_holder_caster + // where the two classes need different logic (without having to resort to + // virtual inheritance). + template + PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { + if (!src) + return false; + if (!typeinfo) + return try_load_foreign_module_local(src); + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; - } - bool try_direct_conversions(handle src) { - for (auto &converter : *typeinfo->direct_conversions) { - if (converter(src.ptr(), value)) - return true; - } - return false; - } - void check_holder_compat() {} - - PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { - auto caster = type_caster_generic(ti); - if (caster.load(src, false)) - return caster.value; - return nullptr; + value = nullptr; + return true; } - /// Try to load with foreign typeinfo, if available. Used when there is no - /// native typeinfo, or when the native one wasn't able to produce a value. - PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { - constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; - const auto pytype = src.get_type(); - if (!hasattr(pytype, local_key)) - return false; + auto &this_ = static_cast(*this); + this_.check_holder_compat(); - type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); - // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type - if (foreign_typeinfo->module_local_load == &local_load - || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) - return false; + PyTypeObject *srctype = Py_TYPE(src.ptr()); - if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { - value = result; + // Case 1: If src is an exact type match for the target type then we can + // reinterpret_cast the instance's value pointer to the target type: + if (srctype == typeinfo->type) { + this_.load_value( + reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2: We have a derived class + else if (PyType_IsSubtype(srctype, typeinfo->type)) { + auto &bases = all_type_info(srctype); + bool no_cpp_mi = typeinfo->simple_type; + + // Case 2a: the python type is a Python-inherited derived class that + // inherits from just one simple (no MI) pybind11 class, or is an exact + // match, so the C++ instance is of the right type and we can use + // reinterpret_cast. (This is essentially the same as case 2b, but because + // not using multiple inheritance is extremely common, we handle it + // specially to avoid the loop iterator and type pointer lookup overhead) + if (bases.size() == 1 && + (no_cpp_mi || bases.front()->type == typeinfo->type)) { + this_.load_value( + reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2b: the python type inherits from multiple C++ bases. Check the + // bases to see if we can find an exact match (or, for a simple C++ type, + // an inherited match); if so, we can safely reinterpret_cast to the + // relevant pointer. + else if (bases.size() > 1) { + for (auto base : bases) { + if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) + : base->type == typeinfo->type) { + this_.load_value( + reinterpret_cast(src.ptr())->get_value_and_holder( + base)); return true; + } } - return false; + } + + // Case 2c: C++ multiple inheritance is involved and we couldn't find an + // exact type match in the registered bases, above, so try implicit + // casting (needed for proper C++ casting when MI is involved). + if (this_.try_implicit_casts(src, convert)) + return true; } - // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant - // bits of code between here and copyable_holder_caster where the two classes need different - // logic (without having to resort to virtual inheritance). - template - PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { - if (!src) return false; - if (!typeinfo) return try_load_foreign_module_local(src); - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - value = nullptr; - return true; + // Perform an implicit conversion + if (convert) { + for (auto &converter : typeinfo->implicit_conversions) { + auto temp = + reinterpret_steal(converter(src.ptr(), typeinfo->type)); + if (load_impl(temp, false)) { + loader_life_support::add_patient(temp); + return true; } - - auto &this_ = static_cast(*this); - this_.check_holder_compat(); - - PyTypeObject *srctype = Py_TYPE(src.ptr()); - - // Case 1: If src is an exact type match for the target type then we can reinterpret_cast - // the instance's value pointer to the target type: - if (srctype == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2: We have a derived class - else if (PyType_IsSubtype(srctype, typeinfo->type)) { - auto &bases = all_type_info(srctype); - bool no_cpp_mi = typeinfo->simple_type; - - // Case 2a: the python type is a Python-inherited derived class that inherits from just - // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of - // the right type and we can use reinterpret_cast. - // (This is essentially the same as case 2b, but because not using multiple inheritance - // is extremely common, we handle it specially to avoid the loop iterator and type - // pointer lookup overhead) - if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if - // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we - // can safely reinterpret_cast to the relevant pointer. - else if (bases.size() > 1) { - for (auto base : bases) { - if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); - return true; - } - } - } - - // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match - // in the registered bases, above, so try implicit casting (needed for proper C++ casting - // when MI is involved). - if (this_.try_implicit_casts(src, convert)) - return true; - } - - // Perform an implicit conversion - if (convert) { - for (auto &converter : typeinfo->implicit_conversions) { - auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); - if (load_impl(temp, false)) { - loader_life_support::add_patient(temp); - return true; - } - } - if (this_.try_direct_conversions(src)) - return true; - } - - // Failed to match local typeinfo. Try again with global. - if (typeinfo->module_local) { - if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { - typeinfo = gtype; - return load(src, false); - } - } - - // Global typeinfo has precedence over foreign module_local - return try_load_foreign_module_local(src); + } + if (this_.try_direct_conversions(src)) + return true; } - - // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast - // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair - // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). - PYBIND11_NOINLINE static std::pair src_and_type( - const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { - if (auto *tpi = get_type_info(cast_type)) - return {src, const_cast(tpi)}; - - // Not found, set error: - std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); - detail::clean_type_id(tname); - std::string msg = "Unregistered type : " + tname; - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return {nullptr, nullptr}; + // Failed to match local typeinfo. Try again with global. + if (typeinfo->module_local) { + if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { + typeinfo = gtype; + return load(src, false); + } } - const type_info *typeinfo = nullptr; - const std::type_info *cpptype = nullptr; - void *value = nullptr; + // Global typeinfo has precedence over foreign module_local + return try_load_foreign_module_local(src); + } + + // Called to do type lookup and wrap the pointer and type in a pair when a + // dynamic_cast isn't needed or can't be used. If the type is unknown, sets + // the error and returns a pair with .second = nullptr. (p.first = nullptr is + // not an error: it becomes None). + PYBIND11_NOINLINE static std::pair + src_and_type(const void *src, const std::type_info &cast_type, + const std::type_info *rtti_type = nullptr) { + if (auto *tpi = get_type_info(cast_type)) + return {src, const_cast(tpi)}; + + // Not found, set error: + std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); + detail::clean_type_id(tname); + std::string msg = "Unregistered type : " + tname; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return {nullptr, nullptr}; + } + + const type_info *typeinfo = nullptr; + const std::type_info *cpptype = nullptr; + void *value = nullptr; }; /** - * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster - * needs to provide `operator T*()` and `operator T&()` operators. + * Determine suitable casting operator for pointer-or-lvalue-casting type + * casters. The type caster needs to provide `operator T*()` and `operator + * T&()` operators. * - * If the type supports moving the value away via an `operator T&&() &&` method, it should use - * `movable_cast_op_type` instead. + * If the type supports moving the value away via an `operator T&&() &&` method, + * it should use `movable_cast_op_type` instead. */ template using cast_op_type = conditional_t>::value, - typename std::add_pointer>::type, - typename std::add_lvalue_reference>::type>; + typename std::add_pointer>::type, + typename std::add_lvalue_reference>::type>; /** - * Determine suitable casting operator for a type caster with a movable value. Such a type caster - * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be - * called in appropriate contexts where the value can be moved rather than copied. + * Determine suitable casting operator for a type caster with a movable value. + * Such a type caster needs to provide `operator T*()`, `operator T&()`, and + * `operator T&&() &&`. The latter will be called in appropriate contexts where + * the value can be moved rather than copied. * - * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. + * These operator are automatically provided when using the PYBIND11_TYPE_CASTER + * macro. */ template -using movable_cast_op_type = - conditional_t::type>::value, - typename std::add_pointer>::type, +using movable_cast_op_type = conditional_t< + std::is_pointer::type>::value, + typename std::add_pointer>::type, conditional_t::value, - typename std::add_rvalue_reference>::type, - typename std::add_lvalue_reference>::type>>; + typename std::add_rvalue_reference>::type, + typename std::add_lvalue_reference>::type>>; -// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when -// T is non-copyable, but code containing such a copy constructor fails to actually compile. -template struct is_copy_constructible : std::is_copy_constructible {}; +// std::is_copy_constructible isn't quite enough: it lets std::vector (and +// similar) through when T is non-copyable, but code containing such a copy +// constructor fails to actually compile. +template +struct is_copy_constructible : std::is_copy_constructible {}; -// Specialization for types that appear to be copy constructible but also look like stl containers -// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if -// so, copy constructability depends on whether the value_type is copy constructible. -template struct is_copy_constructible, - std::is_same - >::value>> : is_copy_constructible {}; +// Specialization for types that appear to be copy constructible but also look +// like stl containers (we specifically check for: has `value_type` and +// `reference` with `reference = value_type&`): if so, copy constructability +// depends on whether the value_type is copy constructible. +template +struct is_copy_constructible< + Container, + enable_if_t, + std::is_same>::value>> + : is_copy_constructible {}; #if !defined(PYBIND11_CPP17) -// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the -// two types aren't themselves copy constructible). -template struct is_copy_constructible> +// Likewise for std::pair before C++17 (which mandates that the copy constructor +// not exist when the two types aren't themselves copy constructible). +template +struct is_copy_constructible> : all_of, is_copy_constructible> {}; #endif NAMESPACE_END(detail) -// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed -// to by `src` actually is an instance of some class derived from `itype`. -// If so, it sets `tinfo` to point to the std::type_info representing that derived -// type, and returns a pointer to the start of the most-derived object of that type -// (in which `src` is a subobject; this will be the same address as `src` in most -// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` -// and leaves `tinfo` at its default value of nullptr. +// polymorphic_type_hook::get(src, tinfo) determines whether the object +// pointed to by `src` actually is an instance of some class derived from +// `itype`. If so, it sets `tinfo` to point to the std::type_info representing +// that derived type, and returns a pointer to the start of the most-derived +// object of that type (in which `src` is a subobject; this will be the same +// address as `src` in most single inheritance cases). If not, or if `src` is +// nullptr, it simply returns `src` and leaves `tinfo` at its default value of +// nullptr. // -// The default polymorphic_type_hook just returns src. A specialization for polymorphic -// types determines the runtime type of the passed object and adjusts the this-pointer -// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear -// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is -// registered with pybind11, and this Animal is in fact a Dog). +// The default polymorphic_type_hook just returns src. A specialization for +// polymorphic types determines the runtime type of the passed object and +// adjusts the this-pointer appropriately via dynamic_cast. This is what +// enables a C++ Animal* to appear to Python as a Dog (if Dog inherits from +// Animal, Animal is polymorphic, Dog is registered with pybind11, and this +// Animal is in fact a Dog). // -// You may specialize polymorphic_type_hook yourself for types that want to appear -// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern -// in performance-sensitive applications, used most notably in LLVM.) -template -struct polymorphic_type_hook -{ - static const void *get(const itype *src, const std::type_info*&) { return src; } +// You may specialize polymorphic_type_hook yourself for types that want to +// appear polymorphic to Python but do not use C++ RTTI. (This is a not uncommon +// pattern in performance-sensitive applications, used most notably in LLVM.) +template struct polymorphic_type_hook { + static const void *get(const itype *src, const std::type_info *&) { + return src; + } }; template -struct polymorphic_type_hook::value>> -{ - static const void *get(const itype *src, const std::type_info*& type) { - type = src ? &typeid(*src) : nullptr; - return dynamic_cast(src); - } +struct polymorphic_type_hook< + itype, detail::enable_if_t::value>> { + static const void *get(const itype *src, const std::type_info *&type) { + type = src ? &typeid(*src) : nullptr; + return dynamic_cast(src); + } }; NAMESPACE_BEGIN(detail) /// Generic type caster for objects stored on the heap template class type_caster_base : public type_caster_generic { - using itype = intrinsic_t; + using itype = intrinsic_t; public: - static constexpr auto name = _(); + static constexpr auto name = _(); - type_caster_base() : type_caster_base(typeid(type)) { } - explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } + type_caster_base() : type_caster_base(typeid(type)) {} + explicit type_caster_base(const std::type_info &info) + : type_caster_generic(info) {} - static handle cast(const itype &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast(&src, policy, parent); + static handle cast(const itype &src, return_value_policy policy, + handle parent) { + if (policy == return_value_policy::automatic || + policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + + static handle cast(itype &&src, return_value_policy, handle parent) { + return cast(&src, return_value_policy::move, parent); + } + + // Returns a (pointer, type_info) pair taking care of necessary type lookup + // for a polymorphic type (using RTTI by default, but can be overridden by + // specializing polymorphic_type_hook). If the instance isn't derived, returns + // the base version. + static std::pair + src_and_type(const itype *src) { + auto &cast_type = typeid(itype); + const std::type_info *instance_type = nullptr; + const void *vsrc = polymorphic_type_hook::get(src, instance_type); + if (instance_type && !same_type(cast_type, *instance_type)) { + // This is a base pointer to a derived type. If the derived type is + // registered with pybind11, we want to make the full derived object + // available. In the typical case where itype is polymorphic, we get the + // correct derived pointer (which may be != base pointer) by a + // dynamic_cast to most derived type. If itype is not polymorphic, we + // won't get here except via a user-provided specialization of + // polymorphic_type_hook, and the user has promised that no this-pointer + // adjustment is required in that case, so it's OK to use static_cast. + if (const auto *tpi = get_type_info(*instance_type)) + return {vsrc, tpi}; } + // Otherwise we have either a nullptr, an `itype` pointer, or an unknown + // derived pointer, so don't do a cast + return type_caster_generic::src_and_type(src, cast_type, instance_type); + } - static handle cast(itype &&src, return_value_policy, handle parent) { - return cast(&src, return_value_policy::move, parent); - } + static handle cast(const itype *src, return_value_policy policy, + handle parent) { + auto st = src_and_type(src); + return type_caster_generic::cast(st.first, policy, parent, st.second, + make_copy_constructor(src), + make_move_constructor(src)); + } - // Returns a (pointer, type_info) pair taking care of necessary type lookup for a - // polymorphic type (using RTTI by default, but can be overridden by specializing - // polymorphic_type_hook). If the instance isn't derived, returns the base version. - static std::pair src_and_type(const itype *src) { - auto &cast_type = typeid(itype); - const std::type_info *instance_type = nullptr; - const void *vsrc = polymorphic_type_hook::get(src, instance_type); - if (instance_type && !same_type(cast_type, *instance_type)) { - // This is a base pointer to a derived type. If the derived type is registered - // with pybind11, we want to make the full derived object available. - // In the typical case where itype is polymorphic, we get the correct - // derived pointer (which may be != base pointer) by a dynamic_cast to - // most derived type. If itype is not polymorphic, we won't get here - // except via a user-provided specialization of polymorphic_type_hook, - // and the user has promised that no this-pointer adjustment is - // required in that case, so it's OK to use static_cast. - if (const auto *tpi = get_type_info(*instance_type)) - return {vsrc, tpi}; - } - // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so - // don't do a cast - return type_caster_generic::src_and_type(src, cast_type, instance_type); - } + static handle cast_holder(const itype *src, const void *holder) { + auto st = src_and_type(src); + return type_caster_generic::cast(st.first, + return_value_policy::take_ownership, {}, + st.second, nullptr, nullptr, holder); + } - static handle cast(const itype *src, return_value_policy policy, handle parent) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, policy, parent, st.second, - make_copy_constructor(src), make_move_constructor(src)); - } + template using cast_op_type = detail::cast_op_type; - static handle cast_holder(const itype *src, const void *holder) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, return_value_policy::take_ownership, {}, st.second, - nullptr, nullptr, holder); - } - - template using cast_op_type = detail::cast_op_type; - - operator itype*() { return (type *) value; } - operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } + operator itype *() { return (type *)value; } + operator itype &() { + if (!value) + throw reference_cast_error(); + return *((itype *)value); + } protected: - using Constructor = void *(*)(const void *); + using Constructor = void *(*)(const void *); - /* Only enabled when the types are {copy,move}-constructible *and* when the type - does not have a private operator new implementation. */ - template ::value>> - static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { - return [](const void *arg) -> void * { - return new T(*reinterpret_cast(arg)); - }; - } + /* Only enabled when the types are {copy,move}-constructible *and* when the + type does not have a private operator new implementation. */ + template ::value>> + static auto make_copy_constructor(const T *x) + -> decltype(new T(*x), Constructor{}) { + return [](const void *arg) -> void * { + return new T(*reinterpret_cast(arg)); + }; + } - template ::value>> - static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { - return [](const void *arg) -> void * { - return new T(std::move(*const_cast(reinterpret_cast(arg)))); - }; - } + template ::value>> + static auto make_move_constructor(const T *x) + -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { + return [](const void *arg) -> void * { + return new T( + std::move(*const_cast(reinterpret_cast(arg)))); + }; + } - static Constructor make_copy_constructor(...) { return nullptr; } - static Constructor make_move_constructor(...) { return nullptr; } + static Constructor make_copy_constructor(...) { return nullptr; } + static Constructor make_move_constructor(...) { return nullptr; } }; -template class type_caster : public type_caster_base { }; +template +class type_caster : public type_caster_base {}; template using make_caster = type_caster>; -// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T -template typename make_caster::template cast_op_type cast_op(make_caster &caster) { - return caster.operator typename make_caster::template cast_op_type(); +// Shortcut for calling a caster's `cast_op_type` cast operator for casting a +// type_caster to a T +template +typename make_caster::template cast_op_type +cast_op(make_caster &caster) { + return caster.operator typename make_caster::template cast_op_type(); } -template typename make_caster::template cast_op_type::type> +template +typename make_caster::template cast_op_type< + typename std::add_rvalue_reference::type> cast_op(make_caster &&caster) { - return std::move(caster).operator - typename make_caster::template cast_op_type::type>(); + return std::move(caster).operator typename make_caster:: + template cast_op_type::type>(); } template class type_caster> { private: - using caster_t = make_caster; - caster_t subcaster; - using subcaster_cast_op_type = typename caster_t::template cast_op_type; - static_assert(std::is_same::type &, subcaster_cast_op_type>::value, - "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); + using caster_t = make_caster; + caster_t subcaster; + using subcaster_cast_op_type = typename caster_t::template cast_op_type; + static_assert(std::is_same::type &, + subcaster_cast_op_type>::value, + "std::reference_wrapper caster requires T to have a caster " + "with an `T &` operator"); + public: - bool load(handle src, bool convert) { return subcaster.load(src, convert); } - static constexpr auto name = caster_t::name; - static handle cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { - // It is definitely wrong to take ownership of this pointer, so mask that rvp - if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic) - policy = return_value_policy::automatic_reference; - return caster_t::cast(&src.get(), policy, parent); - } - template using cast_op_type = std::reference_wrapper; - operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } + bool load(handle src, bool convert) { return subcaster.load(src, convert); } + static constexpr auto name = caster_t::name; + static handle cast(const std::reference_wrapper &src, + return_value_policy policy, handle parent) { + // It is definitely wrong to take ownership of this pointer, so mask that + // rvp + if (policy == return_value_policy::take_ownership || + policy == return_value_policy::automatic) + policy = return_value_policy::automatic_reference; + return caster_t::cast(&src.get(), policy, parent); + } + template using cast_op_type = std::reference_wrapper; + operator std::reference_wrapper() { + return subcaster.operator subcaster_cast_op_type &(); + } }; -#define PYBIND11_TYPE_CASTER(type, py_name) \ - protected: \ - type value; \ - public: \ - static constexpr auto name = py_name; \ - template >::value, int> = 0> \ - static handle cast(T_ *src, return_value_policy policy, handle parent) { \ - if (!src) return none().release(); \ - if (policy == return_value_policy::take_ownership) { \ - auto h = cast(std::move(*src), policy, parent); delete src; return h; \ - } else { \ - return cast(*src, policy, parent); \ - } \ - } \ - operator type*() { return &value; } \ - operator type&() { return value; } \ - operator type&&() && { return std::move(value); } \ - template using cast_op_type = pybind11::detail::movable_cast_op_type +#define PYBIND11_TYPE_CASTER(type, py_name) \ +protected: \ + type value; \ + \ +public: \ + static constexpr auto name = py_name; \ + template >::value, int> = 0> \ + static handle cast(T_ *src, return_value_policy policy, handle parent) { \ + if (!src) \ + return none().release(); \ + if (policy == return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); \ + delete src; \ + return h; \ + } else { \ + return cast(*src, policy, parent); \ + } \ + } \ + operator type *() { return &value; } \ + operator type &() { return value; } \ + operator type &&() && { return std::move(value); } \ + template \ + using cast_op_type = pybind11::detail::movable_cast_op_type - -template using is_std_char_type = any_of< - std::is_same, /* std::string */ - std::is_same, /* std::u16string */ - std::is_same, /* std::u32string */ - std::is_same /* std::wstring */ ->; +template +using is_std_char_type = + any_of, /* std::string */ + std::is_same, /* std::u16string */ + std::is_same, /* std::u32string */ + std::is_same /* std::wstring */ + >; template -struct type_caster::value && !is_std_char_type::value>> { - using _py_type_0 = conditional_t; - using _py_type_1 = conditional_t::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>; - using py_type = conditional_t::value, double, _py_type_1>; +struct type_caster::value && + !is_std_char_type::value>> { + using _py_type_0 = conditional_t; + using _py_type_1 = + conditional_t::value, _py_type_0, + typename std::make_unsigned<_py_type_0>::type>; + using py_type = + conditional_t::value, double, _py_type_1>; + public: + bool load(handle src, bool convert) { + py_type py_value; - bool load(handle src, bool convert) { - py_type py_value; + if (!src) + return false; - if (!src) - return false; + if (std::is_floating_point::value) { + if (convert || PyFloat_Check(src.ptr())) + py_value = (py_type)PyFloat_AsDouble(src.ptr()); + else + return false; + } else if (PyFloat_Check(src.ptr())) { + return false; + } else if (std::is_unsigned::value) { + py_value = as_unsigned(src.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type)PyLong_AsLong(src.ptr()) + : (py_type)PYBIND11_LONG_AS_LONGLONG(src.ptr()); + } - if (std::is_floating_point::value) { - if (convert || PyFloat_Check(src.ptr())) - py_value = (py_type) PyFloat_AsDouble(src.ptr()); - else - return false; - } else if (PyFloat_Check(src.ptr())) { - return false; - } else if (std::is_unsigned::value) { - py_value = as_unsigned(src.ptr()); - } else { // signed integer: - py_value = sizeof(T) <= sizeof(long) - ? (py_type) PyLong_AsLong(src.ptr()) - : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); - } - - bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); - if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && - (py_value < (py_type) std::numeric_limits::min() || - py_value > (py_type) std::numeric_limits::max()))) { - bool type_error = py_err && PyErr_ExceptionMatches( + bool py_err = py_value == (py_type)-1 && PyErr_Occurred(); + if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && + (py_value < (py_type)std::numeric_limits::min() || + py_value > (py_type)std::numeric_limits::max()))) { + bool type_error = py_err && PyErr_ExceptionMatches( #if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) - PyExc_SystemError + PyExc_SystemError #else - PyExc_TypeError + PyExc_TypeError #endif - ); - PyErr_Clear(); - if (type_error && convert && PyNumber_Check(src.ptr())) { - auto tmp = reinterpret_steal(std::is_floating_point::value - ? PyNumber_Float(src.ptr()) - : PyNumber_Long(src.ptr())); - PyErr_Clear(); - return load(tmp, false); - } - return false; - } - - value = (T) py_value; - return true; + ); + PyErr_Clear(); + if (type_error && convert && PyNumber_Check(src.ptr())) { + auto tmp = reinterpret_steal(std::is_floating_point::value + ? PyNumber_Float(src.ptr()) + : PyNumber_Long(src.ptr())); + PyErr_Clear(); + return load(tmp, false); + } + return false; } - template - static typename std::enable_if::value, handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyFloat_FromDouble((double) src); - } + value = (T)py_value; + return true; + } - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) <= sizeof(long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PYBIND11_LONG_FROM_SIGNED((long) src); - } + template + static typename std::enable_if::value, handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyFloat_FromDouble((double)src); + } - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src); - } + template + static typename std::enable_if::value && + std::is_signed::value && + (sizeof(U) <= sizeof(long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_SIGNED((long)src); + } - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) > sizeof(long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyLong_FromLongLong((long long) src); - } + template + static typename std::enable_if::value && + std::is_unsigned::value && + (sizeof(U) <= sizeof(unsigned long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_UNSIGNED((unsigned long)src); + } - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) > sizeof(unsigned long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyLong_FromUnsignedLongLong((unsigned long long) src); - } + template + static typename std::enable_if::value && + std::is_signed::value && + (sizeof(U) > sizeof(long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromLongLong((long long)src); + } - PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); + template + static typename std::enable_if::value && + std::is_unsigned::value && + (sizeof(U) > sizeof(unsigned long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromUnsignedLongLong((unsigned long long)src); + } + + PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); }; -template struct void_caster { +template struct void_caster { public: - bool load(handle src, bool) { - if (src && src.is_none()) - return true; - return false; - } - static handle cast(T, return_value_policy /* policy */, handle /* parent */) { - return none().inc_ref(); - } - PYBIND11_TYPE_CASTER(T, _("None")); + bool load(handle src, bool) { + if (src && src.is_none()) + return true; + return false; + } + static handle cast(T, return_value_policy /* policy */, handle /* parent */) { + return none().inc_ref(); + } + PYBIND11_TYPE_CASTER(T, _("None")); }; template <> class type_caster : public void_caster {}; template <> class type_caster : public type_caster { public: - using type_caster::cast; + using type_caster::cast; - bool load(handle h, bool) { - if (!h) { - return false; - } else if (h.is_none()) { - value = nullptr; - return true; - } - - /* Check if this is a capsule */ - if (isinstance(h)) { - value = reinterpret_borrow(h); - return true; - } - - /* Check if this is a C++ type */ - auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr()); - if (bases.size() == 1) { // Only allowing loading from a single-value type - value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); - return true; - } - - /* Fail */ - return false; + bool load(handle h, bool) { + if (!h) { + return false; + } else if (h.is_none()) { + value = nullptr; + return true; } - static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { - if (ptr) - return capsule(ptr).release(); - else - return none().inc_ref(); + /* Check if this is a capsule */ + if (isinstance(h)) { + value = reinterpret_borrow(h); + return true; } - template using cast_op_type = void*&; - operator void *&() { return value; } - static constexpr auto name = _("capsule"); + /* Check if this is a C++ type */ + auto &bases = all_type_info((PyTypeObject *)h.get_type().ptr()); + if (bases.size() == 1) { // Only allowing loading from a single-value type + value = values_and_holders(reinterpret_cast(h.ptr())) + .begin() + ->value_ptr(); + return true; + } + + /* Fail */ + return false; + } + + static handle cast(const void *ptr, return_value_policy /* policy */, + handle /* parent */) { + if (ptr) + return capsule(ptr).release(); + else + return none().inc_ref(); + } + + template using cast_op_type = void *&; + operator void * &() { return value; } + static constexpr auto name = _("capsule"); + private: - void *value = nullptr; + void *value = nullptr; }; -template <> class type_caster : public void_caster { }; +template <> +class type_caster : public void_caster {}; template <> class type_caster { public: - bool load(handle src, bool convert) { - if (!src) return false; - else if (src.ptr() == Py_True) { value = true; return true; } - else if (src.ptr() == Py_False) { value = false; return true; } - else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { - // (allow non-implicit conversion for numpy booleans) + bool load(handle src, bool convert) { + if (!src) + return false; + else if (src.ptr() == Py_True) { + value = true; + return true; + } else if (src.ptr() == Py_False) { + value = false; + return true; + } else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + // (allow non-implicit conversion for numpy booleans) - Py_ssize_t res = -1; - if (src.is_none()) { - res = 0; // None is implicitly converted to False - } - #if defined(PYPY_VERSION) - // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists - else if (hasattr(src, PYBIND11_BOOL_ATTR)) { - res = PyObject_IsTrue(src.ptr()); - } - #else - // Alternate approach for CPython: this does the same as the above, but optimized - // using the CPython API so as to avoid an unneeded attribute lookup. - else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { - if (PYBIND11_NB_BOOL(tp_as_number)) { - res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); - } - } - #endif - if (res == 0 || res == 1) { - value = (bool) res; - return true; - } + Py_ssize_t res = -1; + if (src.is_none()) { + res = 0; // None is implicitly converted to False + } +#if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr + // exists + else if (hasattr(src, PYBIND11_BOOL_ATTR)) { + res = PyObject_IsTrue(src.ptr()); + } +#else + // Alternate approach for CPython: this does the same as the above, but + // optimized using the CPython API so as to avoid an unneeded attribute + // lookup. + else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { + if (PYBIND11_NB_BOOL(tp_as_number)) { + res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); } - return false; + } +#endif + if (res == 0 || res == 1) { + value = (bool)res; + return true; + } } - static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { - return handle(src ? Py_True : Py_False).inc_ref(); - } - PYBIND11_TYPE_CASTER(bool, _("bool")); + return false; + } + static handle cast(bool src, return_value_policy /* policy */, + handle /* parent */) { + return handle(src ? Py_True : Py_False).inc_ref(); + } + PYBIND11_TYPE_CASTER(bool, _("bool")); }; // Helper class for UTF-{8,16,32} C++ stl strings: template struct string_caster { - using CharT = typename StringType::value_type; + using CharT = typename StringType::value_type; - // Simplify life by being able to assume standard char sizes (the standard only guarantees - // minimums, but Python requires exact sizes) - static_assert(!std::is_same::value || sizeof(CharT) == 1, "Unsupported char size != 1"); - static_assert(!std::is_same::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2"); - static_assert(!std::is_same::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4"); - // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) - static_assert(!std::is_same::value || sizeof(CharT) == 2 || sizeof(CharT) == 4, - "Unsupported wchar_t size != 2/4"); - static constexpr size_t UTF_N = 8 * sizeof(CharT); + // Simplify life by being able to assume standard char sizes (the standard + // only guarantees minimums, but Python requires exact sizes) + static_assert(!std::is_same::value || sizeof(CharT) == 1, + "Unsupported char size != 1"); + static_assert(!std::is_same::value || sizeof(CharT) == 2, + "Unsupported char16_t size != 2"); + static_assert(!std::is_same::value || sizeof(CharT) == 4, + "Unsupported char32_t size != 4"); + // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) + static_assert(!std::is_same::value || sizeof(CharT) == 2 || + sizeof(CharT) == 4, + "Unsupported wchar_t size != 2/4"); + static constexpr size_t UTF_N = 8 * sizeof(CharT); - bool load(handle src, bool) { + bool load(handle src, bool) { #if PY_MAJOR_VERSION < 3 - object temp; + object temp; #endif - handle load_src = src; - if (!src) { - return false; - } else if (!PyUnicode_Check(load_src.ptr())) { + handle load_src = src; + if (!src) { + return false; + } else if (!PyUnicode_Check(load_src.ptr())) { #if PY_MAJOR_VERSION >= 3 - return load_bytes(load_src); + return load_bytes(load_src); #else - if (sizeof(CharT) == 1) { - return load_bytes(load_src); - } + if (sizeof(CharT) == 1) { + return load_bytes(load_src); + } - // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false - if (!PYBIND11_BYTES_CHECK(load_src.ptr())) - return false; + // The below is a guaranteed failure in Python 3 when PyUnicode_Check + // returns false + if (!PYBIND11_BYTES_CHECK(load_src.ptr())) + return false; - temp = reinterpret_steal(PyUnicode_FromObject(load_src.ptr())); - if (!temp) { PyErr_Clear(); return false; } - load_src = temp; + temp = reinterpret_steal(PyUnicode_FromObject(load_src.ptr())); + if (!temp) { + PyErr_Clear(); + return false; + } + load_src = temp; #endif - } - - object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( - load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr)); - if (!utfNbytes) { PyErr_Clear(); return false; } - - const CharT *buffer = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); - size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); - if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32 - value = StringType(buffer, length); - - // If we're loading a string_view we need to keep the encoded Python object alive: - if (IsView) - loader_life_support::add_patient(utfNbytes); - - return true; } - static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) { - const char *buffer = reinterpret_cast(src.data()); - ssize_t nbytes = ssize_t(src.size() * sizeof(CharT)); - handle s = decode_utfN(buffer, nbytes); - if (!s) throw error_already_set(); - return s; + object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( + load_src.ptr(), + UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr)); + if (!utfNbytes) { + PyErr_Clear(); + return false; } - PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); + const CharT *buffer = reinterpret_cast( + PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); + size_t length = + (size_t)PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); + if (UTF_N > 8) { + buffer++; + length--; + } // Skip BOM for UTF-16/32 + value = StringType(buffer, length); + + // If we're loading a string_view we need to keep the encoded Python object + // alive: + if (IsView) + loader_life_support::add_patient(utfNbytes); + + return true; + } + + static handle cast(const StringType &src, return_value_policy /* policy */, + handle /* parent */) { + const char *buffer = reinterpret_cast(src.data()); + ssize_t nbytes = ssize_t(src.size() * sizeof(CharT)); + handle s = decode_utfN(buffer, nbytes); + if (!s) + throw error_already_set(); + return s; + } + + PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); private: - static handle decode_utfN(const char *buffer, ssize_t nbytes) { + static handle decode_utfN(const char *buffer, ssize_t nbytes) { #if !defined(PYPY_VERSION) - return - UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) : - UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) : - PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); + return UTF_N == 8 + ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) + : UTF_N == 16 + ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) + : PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); #else - // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version - // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a - // non-const char * arguments, which is also a nuisance, so bypass the whole thing by just - // passing the encoding as a string value, which works properly: - return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); + // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 + // version sometimes segfaults for unknown reasons, while the UTF16 and 32 + // versions require a non-const char * arguments, which is also a nuisance, + // so bypass the whole thing by just passing the encoding as a string value, + // which works properly: + return PyUnicode_Decode( + buffer, nbytes, + UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); #endif + } + + // When loading into a std::string or char*, accept a bytes object as-is (i.e. + // without any encoding/decoding attempt). For other C++ char sizes this is a + // no-op. which supports loading a unicode from a str, doesn't take this path. + template + bool load_bytes(enable_if_t src) { + if (PYBIND11_BYTES_CHECK(src.ptr())) { + // We were passed a Python 3 raw bytes; accept it into a std::string or + // char* without any encoding attempt. + const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); + if (bytes) { + value = StringType(bytes, (size_t)PYBIND11_BYTES_SIZE(src.ptr())); + return true; + } } - // When loading into a std::string or char*, accept a bytes object as-is (i.e. - // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. - // which supports loading a unicode from a str, doesn't take this path. - template - bool load_bytes(enable_if_t src) { - if (PYBIND11_BYTES_CHECK(src.ptr())) { - // We were passed a Python 3 raw bytes; accept it into a std::string or char* - // without any encoding attempt. - const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); - if (bytes) { - value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); - return true; - } - } + return false; + } - return false; - } - - template - bool load_bytes(enable_if_t) { return false; } + template + bool load_bytes(enable_if_t) { + return false; + } }; template -struct type_caster, enable_if_t::value>> +struct type_caster, + enable_if_t::value>> : string_caster> {}; #ifdef PYBIND11_HAS_STRING_VIEW template -struct type_caster, enable_if_t::value>> +struct type_caster, + enable_if_t::value>> : string_caster, true> {}; #endif -// Type caster for C-style strings. We basically use a std::string type caster, but also add the -// ability to use None as a nullptr char* (which the string caster doesn't allow). -template struct type_caster::value>> { - using StringType = std::basic_string; - using StringCaster = type_caster; - StringCaster str_caster; - bool none = false; - CharT one_char = 0; +// Type caster for C-style strings. We basically use a std::string type caster, +// but also add the ability to use None as a nullptr char* (which the string +// caster doesn't allow). +template +struct type_caster::value>> { + using StringType = std::basic_string; + using StringCaster = type_caster; + StringCaster str_caster; + bool none = false; + CharT one_char = 0; + public: - bool load(handle src, bool convert) { - if (!src) return false; - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - none = true; - return true; + bool load(handle src, bool convert) { + if (!src) + return false; + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) + return false; + none = true; + return true; + } + return str_caster.load(src, convert); + } + + static handle cast(const CharT *src, return_value_policy policy, + handle parent) { + if (src == nullptr) + return pybind11::none().inc_ref(); + return StringCaster::cast(StringType(src), policy, parent); + } + + static handle cast(CharT src, return_value_policy policy, handle parent) { + if (std::is_same::value) { + handle s = PyUnicode_DecodeLatin1((const char *)&src, 1, nullptr); + if (!s) + throw error_already_set(); + return s; + } + return StringCaster::cast(StringType(1, src), policy, parent); + } + + operator CharT *() { + return none ? nullptr + : const_cast( + static_cast(str_caster).c_str()); + } + operator CharT &() { + if (none) + throw value_error("Cannot convert None to a character"); + + auto &value = static_cast(str_caster); + size_t str_len = value.size(); + if (str_len == 0) + throw value_error("Cannot convert empty string to a character"); + + // If we're in UTF-8 mode, we have two possible failures: one for a unicode + // character that is too high, and one for multiple unicode characters + // (caught later), so we need to figure out how long the first encoded + // character is in bytes to distinguish between these two errors. We also + // allow want to allow unicode characters U+0080 through U+00FF, as those + // can fit into a single char value. + if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { + unsigned char v0 = static_cast(value[0]); + size_t char0_bytes = + !(v0 & 0x80) ? 1 : // low bits only: 0-127 + (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence + (v0 & 0xF0) == 0xE0 ? 3 + : // 0b1110xxxx - start of 3-byte sequence + 4; // 0b11110xxx - start of 4-byte sequence + + if (char0_bytes == str_len) { + // If we have a 128-255 value, we can decode it into a single char: + if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx + one_char = static_cast( + ((v0 & 3) << 6) + (static_cast(value[1]) & 0x3F)); + return one_char; } - return str_caster.load(src, convert); + // Otherwise we have a single character, but it's > U+00FF + throw value_error("Character code point not in range(0x100)"); + } } - static handle cast(const CharT *src, return_value_policy policy, handle parent) { - if (src == nullptr) return pybind11::none().inc_ref(); - return StringCaster::cast(StringType(src), policy, parent); + // UTF-16 is much easier: we can only have a surrogate pair for values above + // U+FFFF, thus a surrogate pair with total length 2 instantly indicates a + // range error (but not a "your string was too long" error). + else if (StringCaster::UTF_N == 16 && str_len == 2) { + one_char = static_cast(value[0]); + if (one_char >= 0xD800 && one_char < 0xE000) + throw value_error("Character code point not in range(0x10000)"); } - static handle cast(CharT src, return_value_policy policy, handle parent) { - if (std::is_same::value) { - handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr); - if (!s) throw error_already_set(); - return s; - } - return StringCaster::cast(StringType(1, src), policy, parent); - } + if (str_len != 1) + throw value_error( + "Expected a character, but multi-character string found"); - operator CharT*() { return none ? nullptr : const_cast(static_cast(str_caster).c_str()); } - operator CharT&() { - if (none) - throw value_error("Cannot convert None to a character"); + one_char = value[0]; + return one_char; + } - auto &value = static_cast(str_caster); - size_t str_len = value.size(); - if (str_len == 0) - throw value_error("Cannot convert empty string to a character"); - - // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that - // is too high, and one for multiple unicode characters (caught later), so we need to figure - // out how long the first encoded character is in bytes to distinguish between these two - // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those - // can fit into a single char value. - if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { - unsigned char v0 = static_cast(value[0]); - size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127 - (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence - (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence - 4; // 0b11110xxx - start of 4-byte sequence - - if (char0_bytes == str_len) { - // If we have a 128-255 value, we can decode it into a single char: - if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx - one_char = static_cast(((v0 & 3) << 6) + (static_cast(value[1]) & 0x3F)); - return one_char; - } - // Otherwise we have a single character, but it's > U+00FF - throw value_error("Character code point not in range(0x100)"); - } - } - - // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a - // surrogate pair with total length 2 instantly indicates a range error (but not a "your - // string was too long" error). - else if (StringCaster::UTF_N == 16 && str_len == 2) { - one_char = static_cast(value[0]); - if (one_char >= 0xD800 && one_char < 0xE000) - throw value_error("Character code point not in range(0x10000)"); - } - - if (str_len != 1) - throw value_error("Expected a character, but multi-character string found"); - - one_char = value[0]; - return one_char; - } - - static constexpr auto name = _(PYBIND11_STRING_NAME); - template using cast_op_type = pybind11::detail::cast_op_type<_T>; + static constexpr auto name = _(PYBIND11_STRING_NAME); + template + using cast_op_type = pybind11::detail::cast_op_type<_T>; }; // Base implementation for std::tuple and std::pair -template class Tuple, typename... Ts> class tuple_caster { - using type = Tuple; - static constexpr auto size = sizeof...(Ts); - using indices = make_index_sequence; +template