[CI] run clang-format (#24)
This commit is contained in:
8
.github/workflows/integration-tests.yml
vendored
8
.github/workflows/integration-tests.yml
vendored
@@ -27,10 +27,16 @@ jobs:
|
||||
pip install isort
|
||||
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check style
|
||||
- name: Check python style
|
||||
run: |
|
||||
pip install autopep8
|
||||
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check cpp style
|
||||
run: |
|
||||
sudo apt-get install clang-format
|
||||
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
|
||||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
|
||||
|
||||
- name: Flake8
|
||||
run: |
|
||||
|
@@ -57,9 +57,8 @@ static cl::opt<bool> NoCanonicalizeWhiteSpace(
|
||||
"strict-whitespace",
|
||||
cl::desc("Do not treat all horizontal whitespace as equivalent"));
|
||||
|
||||
static cl::opt<bool> IgnoreCase(
|
||||
"ignore-case",
|
||||
cl::desc("Use case-insensitive matching"));
|
||||
static cl::opt<bool> IgnoreCase("ignore-case",
|
||||
cl::desc("Use case-insensitive matching"));
|
||||
|
||||
static cl::list<std::string> ImplicitCheckNot(
|
||||
"implicit-check-not",
|
||||
@@ -169,12 +168,6 @@ static cl::list<unsigned> DumpInputContexts(
|
||||
|
||||
typedef cl::list<std::string>::const_iterator prefix_iterator;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
static void DumpCommandLine(int argc, char **argv) {
|
||||
errs() << "FileCheck command line: ";
|
||||
for (int I = 0; I < argc; I++)
|
||||
@@ -613,8 +606,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
|
||||
ElidedLinesOS.enable_colors(true);
|
||||
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
|
||||
for (unsigned Line = 1;
|
||||
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd;
|
||||
++Line) {
|
||||
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
|
||||
const unsigned char *InputFileLine = InputFilePtr;
|
||||
|
||||
// Compute the previous and next line included by the filter.
|
||||
@@ -691,8 +683,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
|
||||
unsigned InputLineWidth = InputFilePtr - InputFileLine;
|
||||
|
||||
// Print any annotations.
|
||||
while (AnnotationItr != AnnotationEnd &&
|
||||
AnnotationItr->InputLine == Line) {
|
||||
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
|
||||
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
|
||||
/*BG=*/false, TheColorMode);
|
||||
// The two spaces below are where the ": " appears on input lines.
|
||||
|
@@ -10,11 +10,11 @@
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace test{
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAlignmentPass();
|
||||
}
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
@@ -25,13 +25,11 @@ int main(int argc, char **argv) {
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect,
|
||||
mlir::arith::ArithmeticDialect,
|
||||
mlir::StandardOpsDialect,
|
||||
mlir::scf::SCFDialect>();
|
||||
registry
|
||||
.insert<mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
|
||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
|
||||
mlir::scf::SCFDialect>();
|
||||
|
||||
return mlir::asMainReturnCode(
|
||||
mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry)
|
||||
);
|
||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||
argc, argv, "Triton (GPU) optimizer driver\n", registry));
|
||||
}
|
||||
|
@@ -10,7 +10,6 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -25,26 +24,25 @@ public:
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
AxisInfo(): AxisInfo({}, {}, {}) { }
|
||||
AxisInfo() : AxisInfo({}, {}, {}) {}
|
||||
// Construct contiguity info with known contiguity
|
||||
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
|
||||
ConstancyT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == rank);
|
||||
assert(knownConstancy.size() == rank);
|
||||
}
|
||||
|
||||
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == rank);
|
||||
assert(knownConstancy.size() == rank);
|
||||
}
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const ContiguityT& getContiguity() const { return contiguity; }
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const ContiguityT &getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
const DivisibilityT& getDivisibility() const { return divisibility; }
|
||||
const DivisibilityT &getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const ConstancyT& getConstancy() const { return constancy; }
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const ConstancyT &getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
@@ -56,13 +54,13 @@ public:
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context)
|
||||
{ return AxisInfo(); }
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AxisInfo();
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs);
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
@@ -81,7 +79,7 @@ private:
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
ContiguityT contiguity;
|
||||
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
@@ -107,39 +105,36 @@ private:
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
ConstancyT constancy;
|
||||
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
};
|
||||
|
||||
|
||||
class AxisInfoAnalysis
|
||||
: public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n){
|
||||
if(n==0)
|
||||
|
||||
int highestPowOf2Divisor(int n) {
|
||||
if (n == 0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
|
||||
AxisInfo visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
@@ -3,17 +3,13 @@
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
|
||||
namespace mlir
|
||||
{
|
||||
namespace triton
|
||||
{
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
#endif
|
@@ -3,18 +3,17 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir{
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton{
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps = 4);
|
||||
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
#endif
|
@@ -1,17 +1,16 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||
|
@@ -19,7 +19,7 @@ public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
// The rationale for this number is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
// hang.
|
||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
@@ -31,8 +31,8 @@ public:
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > maxElement)
|
||||
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
|
||||
<< *op << " has more than that";
|
||||
return op->emitError("Maximum allowed number of elements is ")
|
||||
<< maxElement << ", but " << *op << " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
@@ -45,8 +45,8 @@ public:
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > maxElement)
|
||||
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
|
||||
<< *op << " has more than that";
|
||||
return op->emitError("Maximum allowed number of elements is ")
|
||||
<< maxElement << ", but " << *op << " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
@@ -57,7 +57,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
@@ -13,6 +13,6 @@ std::unique_ptr<Pass> createCombineOpsPass();
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
@@ -15,5 +15,4 @@
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
|
||||
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
@@ -14,6 +14,7 @@ namespace mlir {
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numThreads;
|
||||
@@ -21,8 +22,10 @@ private:
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
TritonGPUTypeConverter &typeConverter;
|
||||
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx,
|
||||
TritonGPUTypeConverter &typeConverter);
|
||||
|
||||
/// update layouts & insert ConvertLayoutOp if necessary
|
||||
LogicalResult refineLayouts(ModuleOp mod, int numThreads);
|
||||
|
387
include/triton/driver/dispatch.h
Executable file → Normal file
387
include/triton/driver/dispatch.h
Executable file → Normal file
@@ -3,10 +3,10 @@
|
||||
#ifndef _TRITON_DRIVER_DISPATCH_H_
|
||||
#define _TRITON_DRIVER_DISPATCH_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <dlfcn.h>
|
||||
#include <type_traits>
|
||||
|
||||
//CUDA Backend
|
||||
// CUDA Backend
|
||||
#include "triton/external/CUDA/cuda.h"
|
||||
#include "triton/external/CUDA/nvml.h"
|
||||
|
||||
@@ -14,47 +14,43 @@
|
||||
//#define __HIP_PLATFORM_AMD__
|
||||
#include "triton/external/hip.h"
|
||||
|
||||
//Exceptions
|
||||
// Exceptions
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace llvm {
|
||||
class PassRegistry;
|
||||
class Module;
|
||||
}
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
class cu_context;
|
||||
|
||||
template<class T> void check(T){}
|
||||
template <class T> void check(T) {}
|
||||
void check(CUresult err);
|
||||
void check(hipError_t err);
|
||||
|
||||
class dispatch
|
||||
{
|
||||
class dispatch {
|
||||
protected:
|
||||
template <class F>
|
||||
struct return_type;
|
||||
template <class F> struct return_type;
|
||||
|
||||
template <class R, class... A>
|
||||
struct return_type<R (*)(A...)>
|
||||
{ typedef R type; };
|
||||
template <class R, class... A> struct return_type<R (*)(A...)> {
|
||||
typedef R type;
|
||||
};
|
||||
|
||||
typedef bool (*f_init_t)();
|
||||
|
||||
template<f_init_t initializer, typename FunPtrT, typename... Args>
|
||||
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
|
||||
{
|
||||
template <f_init_t initializer, typename FunPtrT, typename... Args>
|
||||
static typename return_type<FunPtrT>::type
|
||||
f_impl(void *&lib_h, FunPtrT, void *&cache, const char *name, Args... args) {
|
||||
initializer();
|
||||
if(cache == nullptr){
|
||||
if (cache == nullptr) {
|
||||
cache = dlsym(lib_h, name);
|
||||
if(cache == 0)
|
||||
throw std::runtime_error("dlsym unable to load function");
|
||||
}
|
||||
if (cache == 0)
|
||||
throw std::runtime_error("dlsym unable to load function");
|
||||
}
|
||||
FunPtrT fptr;
|
||||
*reinterpret_cast<void **>(&fptr) = cache;
|
||||
typename return_type<FunPtrT>::type res = (*fptr)(args...);
|
||||
@@ -76,63 +72,99 @@ public:
|
||||
// context management
|
||||
static CUresult cuInit(unsigned int Flags);
|
||||
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags,
|
||||
CUdevice dev);
|
||||
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
|
||||
static CUresult cuCtxGetDevice(CUdevice *result);
|
||||
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext,
|
||||
unsigned int flags);
|
||||
static CUresult cuDriverGetVersion(int *driverVersion);
|
||||
// device management
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib,
|
||||
CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
// link management
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type,
|
||||
void *data, size_t size, const char *name,
|
||||
unsigned int numOptions,
|
||||
CUjit_option *options, void **optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions,
|
||||
CUjit_option *options, void **optionValues,
|
||||
CUlinkState *stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void **cubinOut,
|
||||
size_t *sizeOut);
|
||||
static CUresult cuLinkDestroy(CUlinkState state);
|
||||
// module management
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t *bytes,
|
||||
CUmodule hmod, const char *name);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuModuleLoadData(CUmodule *module, const void *image);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image,
|
||||
unsigned int numOptions,
|
||||
CUjit_option *options,
|
||||
void **optionValues);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod,
|
||||
const char *name);
|
||||
// stream management
|
||||
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
|
||||
static CUresult cuStreamSynchronize(CUstream hStream);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext *pctx);
|
||||
static CUresult cuStreamDestroy_v2(CUstream hStream);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX,
|
||||
unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY,
|
||||
unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream,
|
||||
void **kernelParams, void **extra);
|
||||
// function management
|
||||
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
|
||||
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
|
||||
static CUresult cuFuncGetAttribute(int *pi, CUfunction_attribute attrib,
|
||||
CUfunction hfunc);
|
||||
static CUresult cuFuncSetAttribute(CUfunction hfunc,
|
||||
CUfunction_attribute attrib, int value);
|
||||
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
|
||||
// memory management
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
static CUresult cuPointerGetAttribute(void *data,
|
||||
CUpointer_attribute attribute,
|
||||
CUdeviceptr ptr);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N,
|
||||
CUstream stream);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice,
|
||||
size_t ByteCount);
|
||||
static CUresult cuMemFree_v2(CUdeviceptr dptr);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice,
|
||||
size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice,
|
||||
const void *srcHost, size_t ByteCount,
|
||||
CUstream hStream);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost,
|
||||
size_t ByteCount);
|
||||
// event management
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart,
|
||||
CUevent hEnd);
|
||||
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2(const char *pciBusId,
|
||||
nvmlDevice_t *device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device,
|
||||
nvmlClockType_t type,
|
||||
unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device,
|
||||
nvmlClockType_t type,
|
||||
unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device,
|
||||
unsigned int mem_clock,
|
||||
unsigned int sm_clock);
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
@@ -140,177 +172,198 @@ public:
|
||||
// context management
|
||||
static hipError_t hipInit(unsigned int Flags);
|
||||
static hipError_t hipCtxDestroy(hipCtx_t ctx);
|
||||
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
|
||||
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags,
|
||||
hipDevice_t dev);
|
||||
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
|
||||
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
|
||||
static hipError_t hipCtxGetDevice(hipDevice_t* result);
|
||||
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
|
||||
static hipError_t hipCtxGetDevice(hipDevice_t *result);
|
||||
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext,
|
||||
unsigned int flags);
|
||||
static hipError_t hipDriverGetVersion(int *driverVersion);
|
||||
// device management
|
||||
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
|
||||
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
|
||||
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
|
||||
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
|
||||
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib,
|
||||
hipDevice_t dev);
|
||||
static hipError_t hipGetDeviceCount(int *count);
|
||||
// module management
|
||||
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
|
||||
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
|
||||
hipModule_t hmod, const char *name);
|
||||
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
|
||||
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
|
||||
static hipError_t hipModuleLoadData(hipModule_t *module, const void *image);
|
||||
static hipError_t hipModuleUnload(hipModule_t hmod);
|
||||
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
|
||||
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
|
||||
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image,
|
||||
unsigned int numOptions,
|
||||
hipJitOption *options,
|
||||
void **optionValues);
|
||||
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod,
|
||||
const char *name);
|
||||
// stream management
|
||||
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
|
||||
static hipError_t hipStreamSynchronize(hipStream_t hStream);
|
||||
static hipError_t hipStreamDestroy(hipStream_t hStream);
|
||||
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
|
||||
static hipError_t
|
||||
hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX,
|
||||
unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY,
|
||||
unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
hipStream_t hStream, void **kernelParams, void **extra);
|
||||
// function management
|
||||
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
|
||||
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
|
||||
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
|
||||
static hipError_t hipFuncGetAttributes(hipFuncAttributes *attrib,
|
||||
void *hfunc);
|
||||
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc,
|
||||
hipFuncAttribute attrib, int value);
|
||||
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc,
|
||||
hipFuncCache_t config);
|
||||
// memory management
|
||||
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
|
||||
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
|
||||
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
|
||||
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
|
||||
static hipError_t hipPointerGetAttribute(void *data,
|
||||
CUpointer_attribute attribute,
|
||||
hipDeviceptr_t ptr);
|
||||
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x,
|
||||
size_t N, hipStream_t stream);
|
||||
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice,
|
||||
size_t ByteCount);
|
||||
static hipError_t hipFree(hipDeviceptr_t dptr);
|
||||
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
|
||||
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice,
|
||||
size_t ByteCount, hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice,
|
||||
const void *srcHost, size_t ByteCount,
|
||||
hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost,
|
||||
size_t ByteCount);
|
||||
// event management
|
||||
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
|
||||
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
|
||||
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart,
|
||||
hipEvent_t hEnd);
|
||||
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
|
||||
static hipError_t hipEventDestroy(hipEvent_t hEvent);
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
static void* cuda_;
|
||||
static void* nvml_;
|
||||
static void* hip_;
|
||||
|
||||
static void *cuda_;
|
||||
static void *nvml_;
|
||||
static void *hip_;
|
||||
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
// context management
|
||||
static void* cuCtxGetCurrent_;
|
||||
static void* cuCtxSetCurrent_;
|
||||
static void* cuCtxDestroy_v2_;
|
||||
static void* cuCtxCreate_v2_;
|
||||
static void* cuCtxGetDevice_;
|
||||
static void* cuCtxPushCurrent_v2_;
|
||||
static void* cuCtxPopCurrent_v2_;
|
||||
static void* cuCtxEnablePeerAccess_;
|
||||
static void* cuDriverGetVersion_;
|
||||
static void* cuInit_;
|
||||
static void *cuCtxGetCurrent_;
|
||||
static void *cuCtxSetCurrent_;
|
||||
static void *cuCtxDestroy_v2_;
|
||||
static void *cuCtxCreate_v2_;
|
||||
static void *cuCtxGetDevice_;
|
||||
static void *cuCtxPushCurrent_v2_;
|
||||
static void *cuCtxPopCurrent_v2_;
|
||||
static void *cuCtxEnablePeerAccess_;
|
||||
static void *cuDriverGetVersion_;
|
||||
static void *cuInit_;
|
||||
// device management
|
||||
static void* cuDeviceGet_;
|
||||
static void* cuDeviceGetName_;
|
||||
static void* cuDeviceGetPCIBusId_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
static void *cuDeviceGet_;
|
||||
static void *cuDeviceGetName_;
|
||||
static void *cuDeviceGetPCIBusId_;
|
||||
static void *cuDeviceGetAttribute_;
|
||||
static void *cuDeviceGetCount_;
|
||||
// link management
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
static void* cuLinkComplete_;
|
||||
static void *cuLinkAddData_v2_;
|
||||
static void *cuLinkCreate_v2_;
|
||||
static void *cuLinkDestroy_;
|
||||
static void *cuLinkComplete_;
|
||||
// module management
|
||||
static void* cuModuleGetGlobal_v2_;
|
||||
static void* cuModuleLoad_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuModuleGetFunction_;
|
||||
static void *cuModuleGetGlobal_v2_;
|
||||
static void *cuModuleLoad_;
|
||||
static void *cuModuleUnload_;
|
||||
static void *cuModuleLoadDataEx_;
|
||||
static void *cuModuleLoadData_;
|
||||
static void *cuModuleGetFunction_;
|
||||
// stream management
|
||||
static void* cuStreamCreate_;
|
||||
static void* cuStreamSynchronize_;
|
||||
static void* cuStreamDestroy_v2_;
|
||||
static void* cuStreamGetCtx_;
|
||||
static void* cuLaunchKernel_;
|
||||
static void *cuStreamCreate_;
|
||||
static void *cuStreamSynchronize_;
|
||||
static void *cuStreamDestroy_v2_;
|
||||
static void *cuStreamGetCtx_;
|
||||
static void *cuLaunchKernel_;
|
||||
// function management
|
||||
static void* cuFuncGetAttribute_;
|
||||
static void* cuFuncSetAttribute_;
|
||||
static void* cuFuncSetCacheConfig_;
|
||||
static void *cuFuncGetAttribute_;
|
||||
static void *cuFuncSetAttribute_;
|
||||
static void *cuFuncSetCacheConfig_;
|
||||
// memory management
|
||||
static void* cuMemcpyDtoH_v2_;
|
||||
static void* cuMemFree_v2_;
|
||||
static void* cuMemcpyDtoHAsync_v2_;
|
||||
static void* cuMemcpyHtoDAsync_v2_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuMemsetD8Async_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
static void *cuMemcpyDtoH_v2_;
|
||||
static void *cuMemFree_v2_;
|
||||
static void *cuMemcpyDtoHAsync_v2_;
|
||||
static void *cuMemcpyHtoDAsync_v2_;
|
||||
static void *cuMemcpyHtoD_v2_;
|
||||
static void *cuMemAlloc_v2_;
|
||||
static void *cuMemsetD8Async_;
|
||||
static void *cuPointerGetAttribute_;
|
||||
// event management
|
||||
static void* cuEventCreate_;
|
||||
static void* cuEventElapsedTime_;
|
||||
static void* cuEventRecord_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
static void *cuEventCreate_;
|
||||
static void *cuEventElapsedTime_;
|
||||
static void *cuEventRecord_;
|
||||
static void *cuEventDestroy_v2_;
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static void* nvmlInit_v2_;
|
||||
static void* nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
static void* nvmlDeviceGetClockInfo_;
|
||||
static void* nvmlDeviceGetMaxClockInfo_;
|
||||
static void* nvmlDeviceSetApplicationsClocks_;
|
||||
static void *nvmlInit_v2_;
|
||||
static void *nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
static void *nvmlDeviceGetClockInfo_;
|
||||
static void *nvmlDeviceGetMaxClockInfo_;
|
||||
static void *nvmlDeviceSetApplicationsClocks_;
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
* ------------------- */
|
||||
// context management
|
||||
static void* hipInit_;
|
||||
static void* hipCtxDestroy_;
|
||||
static void* hipCtxCreate_;
|
||||
static void* hipCtxPushCurrent_;
|
||||
static void* hipCtxPopCurrent_;
|
||||
static void* hipCtxGetDevice_;
|
||||
static void* hipCtxEnablePeerAccess_;
|
||||
static void* hipDriverGetVersion_;
|
||||
static void *hipInit_;
|
||||
static void *hipCtxDestroy_;
|
||||
static void *hipCtxCreate_;
|
||||
static void *hipCtxPushCurrent_;
|
||||
static void *hipCtxPopCurrent_;
|
||||
static void *hipCtxGetDevice_;
|
||||
static void *hipCtxEnablePeerAccess_;
|
||||
static void *hipDriverGetVersion_;
|
||||
// device management
|
||||
static void* hipGetDevice_;
|
||||
static void* hipDeviceGetName_;
|
||||
static void* hipDeviceGetPCIBusId_;
|
||||
static void* hipDeviceGetAttribute_;
|
||||
static void* hipGetDeviceCount_;
|
||||
static void *hipGetDevice_;
|
||||
static void *hipDeviceGetName_;
|
||||
static void *hipDeviceGetPCIBusId_;
|
||||
static void *hipDeviceGetAttribute_;
|
||||
static void *hipGetDeviceCount_;
|
||||
// module management
|
||||
static void* hipModuleGetGlobal_;
|
||||
static void* hipModuleLoad_;
|
||||
static void* hipModuleLoadData_;
|
||||
static void* hipModuleUnload_;
|
||||
static void* hipModuleLoadDataEx_;
|
||||
static void* hipModuleGetFunction_;
|
||||
static void *hipModuleGetGlobal_;
|
||||
static void *hipModuleLoad_;
|
||||
static void *hipModuleLoadData_;
|
||||
static void *hipModuleUnload_;
|
||||
static void *hipModuleLoadDataEx_;
|
||||
static void *hipModuleGetFunction_;
|
||||
// stream management
|
||||
static void* hipStreamCreate_;
|
||||
static void* hipStreamSynchronize_;
|
||||
static void* hipStreamDestroy_;
|
||||
static void* hipModuleLaunchKernel_;;
|
||||
static void *hipStreamCreate_;
|
||||
static void *hipStreamSynchronize_;
|
||||
static void *hipStreamDestroy_;
|
||||
static void *hipModuleLaunchKernel_;
|
||||
;
|
||||
// function management
|
||||
static void* hipFuncGetAttributes_;
|
||||
static void* hipFuncSetAttribute_;
|
||||
static void* hipFuncSetCacheConfig_;
|
||||
static void *hipFuncGetAttributes_;
|
||||
static void *hipFuncSetAttribute_;
|
||||
static void *hipFuncSetCacheConfig_;
|
||||
// memory management
|
||||
static void* hipMalloc_;
|
||||
static void* hipPointerGetAttribute_;
|
||||
static void* hipMemsetD8Async_;
|
||||
static void* hipMemcpyDtoH_;
|
||||
static void* hipFree_;
|
||||
static void* hipMemcpyDtoHAsync_;
|
||||
static void* hipMemcpyHtoDAsync_;
|
||||
static void* hipMemcpyHtoD_;
|
||||
static void *hipMalloc_;
|
||||
static void *hipPointerGetAttribute_;
|
||||
static void *hipMemsetD8Async_;
|
||||
static void *hipMemcpyDtoH_;
|
||||
static void *hipFree_;
|
||||
static void *hipMemcpyDtoHAsync_;
|
||||
static void *hipMemcpyHtoDAsync_;
|
||||
static void *hipMemcpyHtoD_;
|
||||
// event management
|
||||
static void* hipEventCreate_;
|
||||
static void* hipEventElapsedTime_;
|
||||
static void* hipEventRecord_;
|
||||
static void* hipEventDestroy_;
|
||||
static void *hipEventCreate_;
|
||||
static void *hipEventElapsedTime_;
|
||||
static void *hipEventRecord_;
|
||||
static void *hipEventDestroy_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
415
include/triton/driver/error.h
Executable file → Normal file
415
include/triton/driver/error.h
Executable file → Normal file
@@ -3,223 +3,252 @@
|
||||
#ifndef _TRITON_DRIVER_ERROR_H_
|
||||
#define _TRITON_DRIVER_ERROR_H_
|
||||
|
||||
#include <exception>
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include <exception>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver {
|
||||
|
||||
namespace driver
|
||||
{
|
||||
namespace exception {
|
||||
|
||||
namespace exception
|
||||
{
|
||||
namespace nvrtc {
|
||||
|
||||
namespace nvrtc
|
||||
{
|
||||
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \
|
||||
class name : public std::exception { \
|
||||
public: \
|
||||
const char *what() const throw() override { return "NVRTC: Error- " msg; } \
|
||||
}
|
||||
|
||||
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \
|
||||
class name: public std::exception { public: const char * what() const throw() override { return "NVRTC: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory, "out of memory");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure,
|
||||
"program creation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input, "invalid input");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program, "invalid program");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option, "invalid option");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(compilation, "compilation");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure,
|
||||
"builtin operation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error, "unknown error");
|
||||
|
||||
#undef TRITON_CREATE_NVRTC_EXCEPTION
|
||||
} // namespace nvrtc
|
||||
|
||||
namespace cuda {
|
||||
class base : public std::exception {};
|
||||
|
||||
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \
|
||||
class name : public base { \
|
||||
public: \
|
||||
const char *what() const throw() override { return "CUDA: Error- " msg; } \
|
||||
}
|
||||
|
||||
|
||||
namespace cuda
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \
|
||||
class name: public base { public:const char * what() const throw() override { return "CUDA: Error- " msg; } }
|
||||
|
||||
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_value, "invalid value");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory, "out of memory");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_initialized, "not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(deinitialized, "deinitialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled, "profiler disabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized,
|
||||
"profiler not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started,
|
||||
"profiler already started");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped,
|
||||
"profiler already stopped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_device, "no device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_device, "invalid device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_image, "invalid image");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_context, "invalid context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_current,
|
||||
"context already current");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(map_failed, "map failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed, "unmap failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped, "array is mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_mapped, "already mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_acquired, "already acquired");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped, "not mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array, "not mapped as array");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit, "unsupported limit");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use, "context already in use");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported,
|
||||
"peer access unsupported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx, "invalid ptx");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context,
|
||||
"invalid graphics context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_source, "invalid source");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(file_not_found, "file not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found,
|
||||
"shared object symbol not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed,
|
||||
"shared object init failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(operating_system, "operating system");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle, "invalid handle");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_found, "not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_ready, "not ready");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_address, "illegal address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources,
|
||||
"launch out of resources");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout, "launch timeout");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing,
|
||||
"launch incompatible texturing");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled,
|
||||
"peer access already enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled,
|
||||
"peer access not enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active, "primary context active");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed, "context is destroyed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(assert_error, "assert");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers, "too many peers");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered,
|
||||
"host memory already registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered,
|
||||
"hot memory not registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error, "hardware stack error");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction, "illegal instruction");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address, "misaligned address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space, "invalid address space");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc, "invalid pc");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_failed, "launch failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_permitted, "not permitted");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_supported, "not supported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unknown, "unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUDA_EXCEPTION
|
||||
} // namespace cuda
|
||||
|
||||
namespace cublas {
|
||||
class base : public std::exception {};
|
||||
|
||||
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \
|
||||
class name : public base { \
|
||||
public: \
|
||||
const char *what() const throw() override { \
|
||||
return "CUBLAS: Error- " msg; \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace cublas
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \
|
||||
class name: public base { public: const char * what() const throw() override { return "CUBLAS: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized, "not initialized");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed, "alloc failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value, "invalid value");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch, "arch mismatch");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error, "mapping error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed, "execution failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error, "internal error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported, "not supported");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(license_error, "license error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(unknown, "unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUBLAS_EXCEPTION
|
||||
} // namespace cublas
|
||||
|
||||
namespace cudnn {
|
||||
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \
|
||||
class name : public std::exception { \
|
||||
public: \
|
||||
const char *what() const throw() override { return "CUDNN: Error- " msg; } \
|
||||
}
|
||||
|
||||
namespace cudnn
|
||||
{
|
||||
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \
|
||||
class name: public std::exception { public: const char * what() const throw() override { return "CUDNN: Error- " msg; } }
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized, "not initialized");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed, "allocation failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(bad_param, "bad param");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(internal_error, "internal error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value, "invalid value");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch, "arch mismatch");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error, "mapping error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed, "execution failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_supported, "not supported");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(license_error, "license error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing,
|
||||
"prerequisite missing");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress, "runtime in progress");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow, "runtime fp overflow");
|
||||
} // namespace cudnn
|
||||
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
|
||||
namespace hip {
|
||||
class base : public std::exception {};
|
||||
|
||||
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \
|
||||
class name : public base { \
|
||||
public: \
|
||||
const char *what() const throw() override { return "HIP: Error- " msg; } \
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
namespace hip
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \
|
||||
class name: public base { public:const char * what() const throw() override { return "HIP: Error- " msg; } }
|
||||
|
||||
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
|
||||
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
|
||||
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
|
||||
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
|
||||
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
|
||||
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_value, "invalid value");
|
||||
TRITON_CREATE_HIP_EXCEPTION(out_of_memory, "out of memory");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_initialized, "not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(deinitialized, "deinitialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled, "profiler disabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized,
|
||||
"profiler not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started,
|
||||
"profiler already started");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped,
|
||||
"profiler already stopped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_device, "no device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_device, "invalid device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_image, "invalid image");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_context, "invalid context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_current, "context already current");
|
||||
TRITON_CREATE_HIP_EXCEPTION(map_failed, "map failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unmap_failed, "unmap failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped, "array is mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_mapped, "already mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_acquired, "already acquired");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped, "not mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array, "not mapped as array");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
|
||||
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit, "unsupported limit");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use, "context already in use");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported, "peer access unsupported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx, "invalid ptx");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context,
|
||||
"invalid graphics context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_source, "invalid source");
|
||||
TRITON_CREATE_HIP_EXCEPTION(file_not_found, "file not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found,
|
||||
"shared object symbol not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed,
|
||||
"shared object init failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(operating_system, "operating system");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_handle, "invalid handle");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_found, "not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_ready, "not ready");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_address, "illegal address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources, "launch out of resources");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_timeout, "launch timeout");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing,
|
||||
"launch incompatible texturing");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled,
|
||||
"peer access already enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled, "peer access not enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(primary_context_active, "primary context active");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed, "context is destroyed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(assert_error, "assert");
|
||||
TRITON_CREATE_HIP_EXCEPTION(too_many_peers, "too many peers");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered,
|
||||
"host memory already registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered,
|
||||
"hot memory not registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error, "hardware stack error");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction, "illegal instruction");
|
||||
TRITON_CREATE_HIP_EXCEPTION(misaligned_address, "misaligned address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space, "invalid address space");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_pc, "invalid pc");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_failed, "launch failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_permitted, "not permitted");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_supported, "not supported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol, "invalid symbol");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unknown, "unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUDA_EXCEPTION
|
||||
}
|
||||
} // namespace hip
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace exception
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
@@ -1,20 +1,21 @@
|
||||
#include <string>
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include <string>
|
||||
|
||||
namespace llvm{
|
||||
namespace llvm {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace driver{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
void init_llvm();
|
||||
std::string path_to_ptxas(int& version);
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||
std::string path_to_ptxas(int &version);
|
||||
std::string llir_to_ptx(llvm::Module *module, int cc, int version);
|
||||
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas_path,
|
||||
int cc);
|
||||
CUmodule ptx_to_cumodule(const std::string &ptx, int cc);
|
||||
std::string llir_to_amdgpu(llvm::Module *module, const std::string &proc);
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string &path);
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
@@ -3,52 +3,55 @@
|
||||
#ifndef _TRITON_TOOLS_BENCH_H_
|
||||
#define _TRITON_TOOLS_BENCH_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
|
||||
namespace triton{
|
||||
namespace tools{
|
||||
namespace triton {
|
||||
namespace tools {
|
||||
|
||||
class timer{
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
class timer {
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
|
||||
public:
|
||||
explicit timer(bool run = false)
|
||||
{ if (run) start(); }
|
||||
explicit timer(bool run = false) {
|
||||
if (run)
|
||||
start();
|
||||
}
|
||||
|
||||
void start()
|
||||
{ _start = high_resolution_clock::now(); }
|
||||
void start() { _start = high_resolution_clock::now(); }
|
||||
|
||||
nanoseconds get() const
|
||||
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
||||
nanoseconds get() const {
|
||||
return std::chrono::duration_cast<nanoseconds>(
|
||||
high_resolution_clock::now() - _start);
|
||||
}
|
||||
|
||||
private:
|
||||
high_resolution_clock::time_point _start;
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
|
||||
{
|
||||
inline double bench(std::function<void()> const &op, driver::stream *stream,
|
||||
size_t warmup = 10, size_t repeat = 200) {
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
for(size_t i = 0; i < warmup; i++)
|
||||
for (size_t i = 0; i < warmup; i++)
|
||||
op();
|
||||
stream->synchronize();
|
||||
tmr.start();
|
||||
for(size_t i = 0; i < repeat; i++){
|
||||
for (size_t i = 0; i < repeat; i++) {
|
||||
op();
|
||||
}
|
||||
stream->synchronize();
|
||||
return (float)tmr.get().count() / repeat;
|
||||
|
||||
// return *std::min_element(times.begin(), times.end());
|
||||
// return *std::min_element(times.begin(), times.end());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace tools
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
@@ -3,16 +3,15 @@
|
||||
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace tools{
|
||||
namespace tools {
|
||||
|
||||
template<class node_t>
|
||||
class graph {
|
||||
template <class node_t> class graph {
|
||||
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||
|
||||
public:
|
||||
@@ -21,27 +20,27 @@ public:
|
||||
|
||||
private:
|
||||
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||
if(nmap)
|
||||
nmap_t *nmap, cmap_t *cmap, int id) const {
|
||||
if (nmap)
|
||||
(*nmap)[x] = id;
|
||||
if(cmap)
|
||||
if (cmap)
|
||||
(*cmap)[id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()) {
|
||||
if (nodes.find(x) != nodes.end()) {
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: edges_.at(x))
|
||||
for (const node_t &y : edges_.at(x))
|
||||
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
|
||||
if(cmap)
|
||||
if (cmap)
|
||||
cmap->clear();
|
||||
if(nmap)
|
||||
if (nmap)
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty()){
|
||||
while (!nodes.empty()) {
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
}
|
||||
}
|
||||
@@ -63,7 +62,7 @@ private:
|
||||
edges_t edges_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace tools
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
@@ -33,154 +33,140 @@
|
||||
#ifndef _TRITON_TOOLS_SHA1_HPP_
|
||||
#define _TRITON_TOOLS_SHA1_HPP_
|
||||
|
||||
namespace sha1
|
||||
namespace sha1 {
|
||||
namespace // local
|
||||
{
|
||||
namespace // local
|
||||
{
|
||||
// Rotate an integer value to left.
|
||||
inline unsigned int rol(const unsigned int value,
|
||||
const unsigned int steps)
|
||||
{
|
||||
return ((value << steps) | (value >> (32 - steps)));
|
||||
}
|
||||
// Rotate an integer value to left.
|
||||
inline unsigned int rol(const unsigned int value, const unsigned int steps) {
|
||||
return ((value << steps) | (value >> (32 - steps)));
|
||||
}
|
||||
|
||||
// Sets the first 16 integers in the buffert to zero.
|
||||
// Used for clearing the W buffert.
|
||||
inline void clearWBuffert(unsigned int* buffert)
|
||||
{
|
||||
for (int pos = 16; --pos >= 0;)
|
||||
{
|
||||
buffert[pos] = 0;
|
||||
}
|
||||
}
|
||||
// Sets the first 16 integers in the buffert to zero.
|
||||
// Used for clearing the W buffert.
|
||||
inline void clearWBuffert(unsigned int *buffert) {
|
||||
for (int pos = 16; --pos >= 0;) {
|
||||
buffert[pos] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline void innerHash(unsigned int* result, unsigned int* w)
|
||||
{
|
||||
unsigned int a = result[0];
|
||||
unsigned int b = result[1];
|
||||
unsigned int c = result[2];
|
||||
unsigned int d = result[3];
|
||||
unsigned int e = result[4];
|
||||
inline void innerHash(unsigned int *result, unsigned int *w) {
|
||||
unsigned int a = result[0];
|
||||
unsigned int b = result[1];
|
||||
unsigned int c = result[2];
|
||||
unsigned int d = result[3];
|
||||
unsigned int e = result[4];
|
||||
|
||||
int round = 0;
|
||||
int round = 0;
|
||||
|
||||
#define sha1macro(func,val) \
|
||||
{ \
|
||||
const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \
|
||||
e = d; \
|
||||
d = c; \
|
||||
c = rol(b, 30); \
|
||||
b = a; \
|
||||
a = t; \
|
||||
}
|
||||
#define sha1macro(func, val) \
|
||||
{ \
|
||||
const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \
|
||||
e = d; \
|
||||
d = c; \
|
||||
c = rol(b, 30); \
|
||||
b = a; \
|
||||
a = t; \
|
||||
}
|
||||
|
||||
while (round < 16)
|
||||
{
|
||||
sha1macro((b & c) | (~b & d), 0x5a827999)
|
||||
++round;
|
||||
}
|
||||
while (round < 20)
|
||||
{
|
||||
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro((b & c) | (~b & d), 0x5a827999)
|
||||
++round;
|
||||
}
|
||||
while (round < 40)
|
||||
{
|
||||
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro(b ^ c ^ d, 0x6ed9eba1)
|
||||
++round;
|
||||
}
|
||||
while (round < 60)
|
||||
{
|
||||
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)
|
||||
++round;
|
||||
}
|
||||
while (round < 80)
|
||||
{
|
||||
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro(b ^ c ^ d, 0xca62c1d6)
|
||||
++round;
|
||||
}
|
||||
while (round < 16) {
|
||||
sha1macro((b & c) | (~b & d), 0x5a827999)++ round;
|
||||
}
|
||||
while (round < 20) {
|
||||
w[round] =
|
||||
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro((b & c) | (~b & d), 0x5a827999)++ round;
|
||||
}
|
||||
while (round < 40) {
|
||||
w[round] =
|
||||
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro(b ^ c ^ d, 0x6ed9eba1)++ round;
|
||||
}
|
||||
while (round < 60) {
|
||||
w[round] =
|
||||
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)++ round;
|
||||
}
|
||||
while (round < 80) {
|
||||
w[round] =
|
||||
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
|
||||
sha1macro(b ^ c ^ d, 0xca62c1d6)++ round;
|
||||
}
|
||||
|
||||
#undef sha1macro
|
||||
#undef sha1macro
|
||||
|
||||
result[0] += a;
|
||||
result[1] += b;
|
||||
result[2] += c;
|
||||
result[3] += d;
|
||||
result[4] += e;
|
||||
}
|
||||
} // namespace
|
||||
result[0] += a;
|
||||
result[1] += b;
|
||||
result[2] += c;
|
||||
result[3] += d;
|
||||
result[4] += e;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
inline void calc(const void* src, const int bytelength, unsigned char* hash)
|
||||
{
|
||||
// Init the result array.
|
||||
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
|
||||
inline void calc(const void *src, const int bytelength, unsigned char *hash) {
|
||||
// Init the result array.
|
||||
unsigned int result[5] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476,
|
||||
0xc3d2e1f0};
|
||||
|
||||
// Cast the void src pointer to be the byte array we can work with.
|
||||
const unsigned char* sarray = (const unsigned char*) src;
|
||||
// Cast the void src pointer to be the byte array we can work with.
|
||||
const unsigned char *sarray = (const unsigned char *)src;
|
||||
|
||||
// The reusable round buffer
|
||||
unsigned int w[80];
|
||||
// The reusable round buffer
|
||||
unsigned int w[80];
|
||||
|
||||
// Loop through all complete 64byte blocks.
|
||||
const int endOfFullBlocks = bytelength - 64;
|
||||
int endCurrentBlock;
|
||||
int currentBlock = 0;
|
||||
// Loop through all complete 64byte blocks.
|
||||
const int endOfFullBlocks = bytelength - 64;
|
||||
int endCurrentBlock;
|
||||
int currentBlock = 0;
|
||||
|
||||
while (currentBlock <= endOfFullBlocks)
|
||||
{
|
||||
endCurrentBlock = currentBlock + 64;
|
||||
while (currentBlock <= endOfFullBlocks) {
|
||||
endCurrentBlock = currentBlock + 64;
|
||||
|
||||
// Init the round buffer with the 64 byte block data.
|
||||
for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4)
|
||||
{
|
||||
// This line will swap endian on big endian and keep endian on little endian.
|
||||
w[roundPos++] = (unsigned int) sarray[currentBlock + 3]
|
||||
| (((unsigned int) sarray[currentBlock + 2]) << 8)
|
||||
| (((unsigned int) sarray[currentBlock + 1]) << 16)
|
||||
| (((unsigned int) sarray[currentBlock]) << 24);
|
||||
}
|
||||
innerHash(result, w);
|
||||
}
|
||||
|
||||
// Handle the last and not full 64 byte block if existing.
|
||||
endCurrentBlock = bytelength - currentBlock;
|
||||
clearWBuffert(w);
|
||||
int lastBlockBytes = 0;
|
||||
for (;lastBlockBytes < endCurrentBlock; ++lastBlockBytes)
|
||||
{
|
||||
w[lastBlockBytes >> 2] |= (unsigned int) sarray[lastBlockBytes + currentBlock] << ((3 - (lastBlockBytes & 3)) << 3);
|
||||
}
|
||||
w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3);
|
||||
if (endCurrentBlock >= 56)
|
||||
{
|
||||
innerHash(result, w);
|
||||
clearWBuffert(w);
|
||||
}
|
||||
w[15] = bytelength << 3;
|
||||
innerHash(result, w);
|
||||
|
||||
// Store hash in result pointer, and make sure we get in in the correct order on both endian models.
|
||||
for (int hashByte = 20; --hashByte >= 0;)
|
||||
{
|
||||
hash[hashByte] = (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff;
|
||||
}
|
||||
// Init the round buffer with the 64 byte block data.
|
||||
for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) {
|
||||
// This line will swap endian on big endian and keep endian on little
|
||||
// endian.
|
||||
w[roundPos++] = (unsigned int)sarray[currentBlock + 3] |
|
||||
(((unsigned int)sarray[currentBlock + 2]) << 8) |
|
||||
(((unsigned int)sarray[currentBlock + 1]) << 16) |
|
||||
(((unsigned int)sarray[currentBlock]) << 24);
|
||||
}
|
||||
innerHash(result, w);
|
||||
}
|
||||
|
||||
inline void toHexString(const unsigned char* hash, char* hexstring)
|
||||
{
|
||||
const char hexDigits[] = { "0123456789abcdef" };
|
||||
// Handle the last and not full 64 byte block if existing.
|
||||
endCurrentBlock = bytelength - currentBlock;
|
||||
clearWBuffert(w);
|
||||
int lastBlockBytes = 0;
|
||||
for (; lastBlockBytes < endCurrentBlock; ++lastBlockBytes) {
|
||||
w[lastBlockBytes >> 2] |=
|
||||
(unsigned int)sarray[lastBlockBytes + currentBlock]
|
||||
<< ((3 - (lastBlockBytes & 3)) << 3);
|
||||
}
|
||||
w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3);
|
||||
if (endCurrentBlock >= 56) {
|
||||
innerHash(result, w);
|
||||
clearWBuffert(w);
|
||||
}
|
||||
w[15] = bytelength << 3;
|
||||
innerHash(result, w);
|
||||
|
||||
for (int hashByte = 20; --hashByte >= 0;)
|
||||
{
|
||||
hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf];
|
||||
hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf];
|
||||
}
|
||||
hexstring[40] = 0;
|
||||
}
|
||||
// Store hash in result pointer, and make sure we get in in the correct order
|
||||
// on both endian models.
|
||||
for (int hashByte = 20; --hashByte >= 0;) {
|
||||
hash[hashByte] =
|
||||
(result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff;
|
||||
}
|
||||
}
|
||||
|
||||
inline void toHexString(const unsigned char *hash, char *hexstring) {
|
||||
const char hexDigits[] = {"0123456789abcdef"};
|
||||
|
||||
for (int hashByte = 20; --hashByte >= 0;) {
|
||||
hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf];
|
||||
hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf];
|
||||
}
|
||||
hexstring[40] = 0;
|
||||
}
|
||||
} // namespace sha1
|
||||
|
||||
#endif
|
||||
|
@@ -7,11 +7,8 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace tools
|
||||
{
|
||||
|
||||
namespace triton {
|
||||
namespace tools {
|
||||
|
||||
#ifdef _WIN32
|
||||
#define popen _popen
|
||||
@@ -19,12 +16,12 @@ namespace tools
|
||||
#endif
|
||||
|
||||
#ifndef WEXITSTATUS
|
||||
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255)
|
||||
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val)&255)
|
||||
#endif
|
||||
|
||||
int exec(const std::string& cmd, std::string& result) {
|
||||
int exec(const std::string &cmd, std::string &result) {
|
||||
char buffer[128];
|
||||
FILE* pipe = popen(cmd.c_str(), "r");
|
||||
FILE *pipe = popen(cmd.c_str(), "r");
|
||||
if (!pipe)
|
||||
return 0;
|
||||
result.clear();
|
||||
@@ -37,10 +34,9 @@ int exec(const std::string& cmd, std::string& result) {
|
||||
}
|
||||
int status = pclose(pipe);
|
||||
return WEXITSTATUS(status);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace tools
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
27
include/triton/tools/sys/getenv.hpp
Executable file → Normal file
27
include/triton/tools/sys/getenv.hpp
Executable file → Normal file
@@ -22,26 +22,23 @@
|
||||
#ifndef TDL_TOOLS_SYS_GETENV_HPP
|
||||
#define TDL_TOOLS_SYS_GETENV_HPP
|
||||
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace triton {
|
||||
|
||||
namespace tools
|
||||
{
|
||||
|
||||
inline std::string getenv(const char * name)
|
||||
{
|
||||
const char * cstr = std::getenv(name);
|
||||
if(!cstr)
|
||||
return "";
|
||||
std::string result(cstr);
|
||||
return result;
|
||||
}
|
||||
namespace tools {
|
||||
|
||||
inline std::string getenv(const char *name) {
|
||||
const char *cstr = std::getenv(name);
|
||||
if (!cstr)
|
||||
return "";
|
||||
std::string result(cstr);
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace tools
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
74
include/triton/tools/sys/mkdir.hpp
Executable file → Normal file
74
include/triton/tools/sys/mkdir.hpp
Executable file → Normal file
@@ -22,55 +22,49 @@
|
||||
#ifndef TDL_TOOLS_SYS_MKDIR_HPP
|
||||
#define TDL_TOOLS_SYS_MKDIR_HPP
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <sys/stat.h>
|
||||
#include <cstring>
|
||||
#include <errno.h>
|
||||
#include <string>
|
||||
#include <sys/stat.h>
|
||||
#if defined(_WIN32)
|
||||
#include <direct.h>
|
||||
#include <direct.h>
|
||||
#endif
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace triton {
|
||||
|
||||
namespace tools
|
||||
{
|
||||
|
||||
inline int mkdir(std::string const & path)
|
||||
{
|
||||
#if defined(_WIN32)
|
||||
return _mkdir(path.c_str());
|
||||
#else
|
||||
return ::mkdir(path.c_str(), 0777);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int mkpath(std::string const & path)
|
||||
{
|
||||
int status = 0;
|
||||
size_t pp = 0;
|
||||
size_t sp;
|
||||
while ((sp = path.find('/', pp)) != std::string::npos)
|
||||
{
|
||||
if (sp != pp){
|
||||
status = mkdir(path.substr(0, sp));
|
||||
}
|
||||
pp = sp + 1;
|
||||
}
|
||||
return (status==0 || errno==EEXIST)?0:-1;
|
||||
}
|
||||
|
||||
inline int mtime(std::string const & path)
|
||||
{
|
||||
struct stat st;
|
||||
if(stat(path.c_str(), &st) != 0)
|
||||
return 0;
|
||||
return st.st_mtime;
|
||||
}
|
||||
namespace tools {
|
||||
|
||||
inline int mkdir(std::string const &path) {
|
||||
#if defined(_WIN32)
|
||||
return _mkdir(path.c_str());
|
||||
#else
|
||||
return ::mkdir(path.c_str(), 0777);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int mkpath(std::string const &path) {
|
||||
int status = 0;
|
||||
size_t pp = 0;
|
||||
size_t sp;
|
||||
while ((sp = path.find('/', pp)) != std::string::npos) {
|
||||
if (sp != pp) {
|
||||
status = mkdir(path.substr(0, sp));
|
||||
}
|
||||
pp = sp + 1;
|
||||
}
|
||||
return (status == 0 || errno == EEXIST) ? 0 : -1;
|
||||
}
|
||||
|
||||
inline int mtime(std::string const &path) {
|
||||
struct stat st;
|
||||
if (stat(path.c_str(), &st) != 0)
|
||||
return 0;
|
||||
return st.st_mtime;
|
||||
}
|
||||
|
||||
} // namespace tools
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
|
@@ -3,88 +3,79 @@
|
||||
#ifndef _TRITON_TOOLS_THREAD_POOL_H_
|
||||
#define _TRITON_TOOLS_THREAD_POOL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t threads)
|
||||
: stop(false) {
|
||||
for(size_t i = 0;i < threads;++i)
|
||||
workers.emplace_back(
|
||||
[this] {
|
||||
for(;;){
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
ThreadPool(size_t threads) : stop(false) {
|
||||
for (size_t i = 0; i < threads; ++i)
|
||||
workers.emplace_back([this] {
|
||||
for (;;) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(
|
||||
lock, [this] { return this->stop || !this->tasks.empty(); });
|
||||
if (this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
task();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <class F, class... Args>
|
||||
auto enqueue(F &&f, Args &&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type> {
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
auto task = std::make_shared<std::packaged_task<return_type()>>(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if (stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
tasks.emplace([task]() { (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
~ThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
~ThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
|
||||
condition.notify_all();
|
||||
for (std::thread &worker : workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
// the task queue
|
||||
std::queue< std::function<void()> > tasks;
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector<std::thread> workers;
|
||||
// the task queue
|
||||
std::queue<std::function<void()>> tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
|
||||
#endif
|
||||
|
@@ -8,24 +8,23 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y){
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y) {
|
||||
// Base Case
|
||||
if (a == 0) {
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||
int gcd = gcd_impl(b % a, a, &x1, &y1);
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b/a) * x1;
|
||||
*x = y1 - (b / a) * x1;
|
||||
*y = x1;
|
||||
return gcd;
|
||||
}
|
||||
@@ -35,17 +34,17 @@ static int gcd(int a, int b) {
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
size_t rank = 1;
|
||||
if(TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
if (TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
rank = ty.getRank();
|
||||
int divHint = 1;
|
||||
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
|
||||
Operation* op = blockArg.getOwner()->getParentOp();
|
||||
if(FuncOp fun = dyn_cast<FuncOp>(op)){
|
||||
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if(attr)
|
||||
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
|
||||
Attribute attr =
|
||||
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if (attr)
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
}
|
||||
@@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
ContiguityT retContiguity;
|
||||
DivisibilityT retDivisibility;
|
||||
ConstancyT retConstancy;
|
||||
for(size_t d = 0; d < lhs.getRank(); d++){
|
||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
}
|
||||
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr;
|
||||
// This preserves the input axes (e.g., cast):
|
||||
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||
triton::PtrToIntOp, triton::IntToPtrOp>(op))
|
||||
curr = operands[0]->getValue();
|
||||
// Constant ranges
|
||||
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
|
||||
if (triton::MakeRangeOp make_range =
|
||||
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::ContiguityT contiguity = {end - start};
|
||||
@@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Constant
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
|
||||
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||
if(intAttr){
|
||||
if (intAttr) {
|
||||
size_t val = intAttr.getValue().getZExtValue();
|
||||
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if(splatAttr && splatAttr.getElementType().isInteger(32)){
|
||||
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
|
||||
curr = AxisInfo(
|
||||
AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
};
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Multiplication
|
||||
if (llvm::isa<arith::MulIOp>(op)){
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return 1;
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::MulIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)){
|
||||
if (llvm::isa<triton::SplatOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
@@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
}
|
||||
// Reshape
|
||||
// TODO: Replace by `unsqueeze`
|
||||
if (llvm::isa<triton::ReshapeOp>(op)){
|
||||
if (llvm::isa<triton::ReshapeOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ConstancyT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
if(retShape[d] == 1){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
if (retShape[d] == 1) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
else if(!is_skewed
|
||||
&& retShape[d] == opShape[current]){
|
||||
} else if (!is_skewed && retShape[d] == opShape[current]) {
|
||||
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||
constancy.push_back(opInfo.getConstancy()[current]);
|
||||
current++;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
is_skewed = true;
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
@@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Broadcast
|
||||
if (llvm::isa<triton::BroadcastOp>(op)){
|
||||
if (llvm::isa<triton::BroadcastOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
if(curr.getRank() == 0){
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
@@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace mlir
|
@@ -2,14 +2,16 @@
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace triton{
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
@@ -1,42 +1,42 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
template<class Op>
|
||||
class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template<class SrcOp, class DstOp>
|
||||
template <class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
|
||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||
);
|
||||
DstOp res =
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
assert(value);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
op, retType,
|
||||
value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertArithmeticOp: public ConversionPattern {
|
||||
class ConvertArithmeticOp : public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
Dialect* dialect = op->getDialect();
|
||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Dialect *dialect = op->getDialect();
|
||||
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target){
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
@@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality(
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithConstantPattern,
|
||||
ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>,
|
||||
ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>,
|
||||
ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>,
|
||||
ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>,
|
||||
ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>,
|
||||
ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>,
|
||||
ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>,
|
||||
ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>,
|
||||
ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>,
|
||||
ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>,
|
||||
ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>
|
||||
>(typeConverter, context);
|
||||
patterns.add<
|
||||
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// Triton patterns
|
||||
//
|
||||
// TODO: Do we need to put them in anonymous namespace?
|
||||
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
|
||||
struct TritonMakeRangePattern
|
||||
: public OpConversionPattern<triton::MakeRangeOp> {
|
||||
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, retType, adaptor.start(), adaptor.end()
|
||||
);
|
||||
op, retType, adaptor.start(), adaptor.end());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.a().getType().cast<RankedTensorType>();
|
||||
@@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value b = adaptor.b();
|
||||
SmallVector<unsigned, 2> order{1, 0};
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||
);
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, retType,
|
||||
adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
|
||||
);
|
||||
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||
);
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -212,12 +209,11 @@ template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
|
||||
);
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::SplatOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonReducePattern,
|
||||
TritonMakeRangePattern,
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(typeConverter, context);
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -259,17 +250,19 @@ void populateTritonPatterns(
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
// Ref: ConvertForOpTypes
|
||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp =
|
||||
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||
newOp.getLoopBody().end());
|
||||
|
||||
// Now, update all the types.
|
||||
|
||||
// Convert the types of block arguments within the given region. This
|
||||
// replaces each block with a new block containing the updated signature. The
|
||||
// entry block may have a special conversion if `entryConversion` is
|
||||
// replaces each block with a new block containing the updated signature.
|
||||
// The entry block may have a special conversion if `entryConversion` is
|
||||
// provided. On success, the new entry block to the region is returned for
|
||||
// convenience. Otherwise, failure is returned.
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||
@@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||
// op.erase();
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||
op, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern
|
||||
>(typeConverter, context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
class ConvertTritonToTritonGPU
|
||||
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
public:
|
||||
ConvertTritonToTritonGPU(int numWarps) {
|
||||
this->numWarps = numWarps;
|
||||
}
|
||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -339,21 +326,21 @@ public:
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
if(failed(target.refineLayouts(mod, numWarps)))
|
||||
if (failed(target.refineLayouts(mod, numWarps)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||
|
@@ -7,7 +7,6 @@
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||
Type type, Location loc) {
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
}
|
@@ -13,14 +13,16 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
@@ -48,50 +50,48 @@ namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
|
||||
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
// other
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
resultType,
|
||||
DenseElementsAttr::get(
|
||||
resultType, builder.getZeroAttr(elementType)
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), resultType,
|
||||
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addOperands(other);
|
||||
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(
|
||||
evictAttrName(state.name),
|
||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name),
|
||||
builder.getBoolAttr(isVolatile));
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
||||
|
||||
using namespace mlir;
|
||||
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
}
|
||||
|
||||
Type PointerType::parse(AsmParser &parser) {
|
||||
|
@@ -17,21 +17,23 @@ namespace {
|
||||
class CombineDotOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineDotOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
|
||||
if (isCandidate(op->getOperand(0)).succeeded()) {
|
||||
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(1), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(1), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
} else if (isCandidate(op->getOperand(1)).succeeded()) {
|
||||
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(0), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(0), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -54,7 +56,7 @@ private:
|
||||
return true;
|
||||
// broadcast(constant_0)
|
||||
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
|
||||
return true;
|
||||
}
|
||||
@@ -68,18 +70,19 @@ private:
|
||||
class CombineGEPOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineGEPOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::triton::GEPOp>(op)) {
|
||||
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
|
||||
auto loc = op->getLoc();
|
||||
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
|
||||
);
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -92,20 +95,21 @@ public:
|
||||
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::SelectOp>(op)) {
|
||||
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
|
||||
mlir::Value cond = op->getOperand(0);
|
||||
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
|
||||
op, op->getResultTypes().front(),
|
||||
load.ptr(), load.mask(), op->getOperand(2),
|
||||
load.cache(), load.evict(), load.isVolatile()
|
||||
);
|
||||
op, op->getResultTypes().front(), load.ptr(), load.mask(),
|
||||
op->getOperand(2), load.cache(), load.evict(),
|
||||
load.isVolatile());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -120,11 +124,11 @@ public:
|
||||
class CombineBroadcastConstantOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineBroadcastConstantOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
|
||||
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
|
||||
Attribute value = cst.getValue();
|
||||
@@ -132,15 +136,14 @@ public:
|
||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||
if (!denseValue.isSplat())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
|
||||
value = DenseElementsAttr::get(resType,
|
||||
denseValue.getSplatValue<Attribute>());
|
||||
} else {
|
||||
if (!value.isa<FloatAttr, IntegerAttr>())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, value);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, value, resType
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@@ -11,19 +11,18 @@ using namespace mlir::triton::gpu;
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
/*SmallVector<unsigned, 2>*/auto &res,
|
||||
StringRef desc) {
|
||||
/*SmallVector<unsigned, 2>*/ auto &res,
|
||||
StringRef desc) {
|
||||
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||
<< desc;
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
|
||||
return failure();
|
||||
}
|
||||
for (Attribute i : arrayAttr) {
|
||||
auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||
<< desc;
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
res.push_back(intAttr.getUInt());
|
||||
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
|
||||
SmallVector<unsigned, 2> threadTileSize;
|
||||
SmallVector<unsigned, 2> warpTileSize;
|
||||
SmallVector<unsigned, 2> blockTileSize;
|
||||
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "threadTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "blockTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "order") {
|
||||
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "broadcastAxis") {
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
|
||||
threadTileSize,
|
||||
warpTileSize,
|
||||
blockTileSize,
|
||||
order,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpPerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "repetitions") {
|
||||
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "contigPerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
fragmentPerWarp,
|
||||
shapePerWarp,
|
||||
warpPerTile,
|
||||
shapePerTile,
|
||||
repetitions,
|
||||
contigPerThread,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(
|
||||
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
|
||||
shapePerTile, repetitions, contigPerThread, broadcastAxis);
|
||||
}
|
||||
|
||||
static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
unsigned maxPhase = 0;
|
||||
SmallVector<unsigned, 2> order;
|
||||
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr,
|
||||
unsigned &value,
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
|
||||
StringRef desc) -> LogicalResult {
|
||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
|
||||
vec,
|
||||
perPhase,
|
||||
maxPhase,
|
||||
order);
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(
|
||||
parser.getContext(), vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec()
|
||||
<< ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase()
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
|
||||
<< "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
public:
|
||||
public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
static void printMma(const auto &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
}
|
||||
|
||||
@@ -349,7 +340,8 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult verify(CopyAsyncOp op) {
|
||||
Type resType = op.getResult().getType();
|
||||
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
||||
|
||||
// verify TritonGPU ops
|
||||
LogicalResult
|
||||
TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
|
@@ -27,8 +27,8 @@ namespace {
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
@@ -6,12 +6,11 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
||||
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
||||
// and SCF's LoopPipelining.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
@@ -41,14 +40,15 @@ class LoopPipeliner {
|
||||
/// Block arguments that loads depend on
|
||||
DenseSet<BlockArgument> depArgs;
|
||||
/// Operations (inside the loop body) that loads depend on
|
||||
DenseSet<Operation*> depOps;
|
||||
DenseSet<Operation *> depOps;
|
||||
|
||||
/// collect values that v depends on and are defined inside the loop
|
||||
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
|
||||
|
||||
void setValueMapping(Value origin, Value newValue, int stage);
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
: forOp(forOp), numStages(numStages) {
|
||||
// cache yieldOp
|
||||
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
deps.insert(v);
|
||||
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
|
||||
} else { // value
|
||||
// v might be in deps, but we still need to visit v.
|
||||
// This is because v might depends on value in previous iterations
|
||||
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
}
|
||||
|
||||
// for (triton::LoadOp loadOp : allLoads) {
|
||||
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
|
||||
// for (Value dep : loadDeps[loadOp])
|
||||
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
|
||||
// " values\n"; for (Value dep : loadDeps[loadOp])
|
||||
// llvm::errs() << dep << "\n";
|
||||
// llvm::errs() << "\n";
|
||||
// }
|
||||
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
if (isCandiate && loadOp.getResult().hasOneUse()) {
|
||||
isCandiate = false;
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
if (auto convertLayout =
|
||||
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding()
|
||||
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
|
||||
// we have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// update depArgs & depOps
|
||||
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
// special handling for loop condition as there is no condition in ForOp
|
||||
Value loopCond = builder.create<arith::CmpIOp>(
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
|
||||
// rematerialize peeled values
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
// TODO: check if the hardware supports copyasync
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
||||
newOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), loadsMapping[loadOp].getType(),
|
||||
loadOp.ptr(), loadOp.mask(), loadOp.other(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
|
||||
);
|
||||
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
|
||||
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
|
||||
loadOp.isVolatile());
|
||||
} else
|
||||
llvm_unreachable("This should be LoadOp");
|
||||
} else
|
||||
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
|
||||
// assert(I1 or TensorOf<[I1]>);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
loopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
mask,
|
||||
splatCond);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
Value newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
newOp->setOperand(1, newMask);
|
||||
}
|
||||
|
||||
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
// update mapping for loop-carried values (args)
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx))
|
||||
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
setValueMapping(
|
||||
forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs) {
|
||||
depArgsIdx[depArg] = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
|
||||
}
|
||||
|
||||
size_t nextIVIdx = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
|
||||
|
||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||
assert(newLoopArgs[i]);
|
||||
|
||||
// 1. signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||
forOp.getLowerBound(),
|
||||
forOp.getUpperBound(),
|
||||
forOp.getStep(),
|
||||
newLoopArgs);
|
||||
auto newForOp = builder.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newLoopArgs);
|
||||
|
||||
// 2. body of the new ForOp
|
||||
builder.setInsertionPointToStart(newForOp.getBody());
|
||||
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// 3. replace loads with block args (from prologue)
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
|
||||
assert(load.hasOneUse() &&
|
||||
"we assume that this load has one use (ConvertLayout)");
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
mapping.lookup(loadUse).replaceAllUsesWith(
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
|
||||
}
|
||||
|
||||
|
||||
// 4. prefetch the next iteration
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
DenseMap<BlockArgument, Value> depArgsMapping;
|
||||
size_t argIdx = 0;
|
||||
for (BlockArgument arg : depArgs) {
|
||||
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
nextMapping.map(arg,
|
||||
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
++argIdx;
|
||||
}
|
||||
// special handling for iv & loop condition
|
||||
Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx],
|
||||
newForOp.getStep());
|
||||
Value nextLoopCond = builder.create<arith::CmpIOp>(
|
||||
nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
Value nextIV = builder.create<arith::AddIOp>(
|
||||
newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
||||
Value nextLoopCond =
|
||||
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// update loading mask
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
splatCond,
|
||||
nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than once
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
// TODO: more elegant way to do this?
|
||||
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
nextMapping.lookupOrDefault(loadOp.other()),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
|
||||
);
|
||||
}
|
||||
else
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile());
|
||||
} else
|
||||
nextOp = builder.clone(*op, nextMapping);
|
||||
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
|
||||
// update mapping of results
|
||||
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
for (int stage = 1; stage < numStages - 1; ++stage) {
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[
|
||||
loadIdx + idx*(numStages-1) + stage
|
||||
]);
|
||||
yieldValues.push_back(
|
||||
newForOp
|
||||
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
|
||||
}
|
||||
yieldValues.push_back(nextMapping.lookup(load));
|
||||
}
|
||||
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
|
||||
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(
|
||||
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(nextIV);
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
PipelinePass() = default;
|
||||
PipelinePass(int numStages) {
|
||||
this->numStages = numStages;
|
||||
}
|
||||
PipelinePass(int numStages) { this->numStages = numStages; }
|
||||
|
||||
void runOnOperation() override {
|
||||
int numStages = this->numStages;
|
||||
|
@@ -1,7 +1,7 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
|
||||
//
|
||||
// TypeConverter
|
||||
//
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numThreads)
|
||||
: context(context), numThreads(numThreads) {
|
||||
// TODO: how does MLIR pick the right conversion?
|
||||
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// or assert no encoding?
|
||||
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
llvm::SmallVector<unsigned> broadcastAxis;
|
||||
int remainingThreads = numThreads;
|
||||
int remainingLanes = /*warp size*/32;
|
||||
int remainingLanes = /*warp size*/ 32;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
||||
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
|
||||
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
|
||||
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
//
|
||||
// This will be called when (newArgType != origArgType)
|
||||
// This will create newArg, and map(origArg, newArg)
|
||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
addArgumentMaterialization([&](OpBuilder &builder,
|
||||
RankedTensorType tensorType, ValueRange inputs,
|
||||
Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// If the origValue still has live user(s), use this to
|
||||
// convert origValue to newValue
|
||||
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// where, desiredType = typeConverter->convertType(origType)
|
||||
// NOTE: only for remapped values.
|
||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// TritonGPUConversion
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
// TODO: we should also verify ops of TritonGPUDialect
|
||||
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||
triton::TritonDialect,
|
||||
StandardOpsDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
|
||||
scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>(
|
||||
[&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
// We have requirements for the data layouts
|
||||
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
|
||||
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
Attribute aEncoding =
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding =
|
||||
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding &&
|
||||
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return true;
|
||||
// // TODO: we should delete this
|
||||
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// %dst = tt.broadcast %src
|
||||
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// %bcst = tt.broadcast %newSrc
|
||||
// %dst = convert_layout %bcst
|
||||
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
int numThreads) {
|
||||
int numThreads) {
|
||||
// collect broadcasts
|
||||
SmallVector<triton::BroadcastOp> broadcasts;
|
||||
mod.walk([&](triton::BroadcastOp op) {
|
||||
broadcasts.push_back(op);
|
||||
});
|
||||
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto broadcast : broadcasts) {
|
||||
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
broadcastAxis.push_back(ax);
|
||||
|
||||
Attribute originSrcEnc = tensorType.getEncoding();
|
||||
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
if (auto blockedEnc =
|
||||
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
|
||||
blockedEnc.getContext(),
|
||||
blockedEnc.getThreadTileSize(),
|
||||
blockedEnc.getWarpTileSize(),
|
||||
blockedEnc.getBlockTileSize(),
|
||||
blockedEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
|
||||
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
|
||||
blockedEnc.getOrder(), broadcastAxis);
|
||||
newSrcType = RankedTensorType::get(
|
||||
tensorType.getShape(),
|
||||
tensorType.getElementType(),
|
||||
newSrcEnc
|
||||
);
|
||||
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
|
||||
} else
|
||||
llvm_unreachable("src of broadcast should have blocked encoding");
|
||||
} else {
|
||||
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
|
||||
// create new src
|
||||
if (!isSrcScalar) // we don't need to convert layout for scalar values
|
||||
src = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
src.getLoc(), newSrcType, src
|
||||
);
|
||||
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
|
||||
newSrcType, src);
|
||||
|
||||
// create new broadcast
|
||||
// compute new type (encoding)
|
||||
auto originDstEnc = originDstTensorType.getEncoding()
|
||||
.dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
.dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto newEnc = TritonGPUBlockedEncodingAttr::get(
|
||||
originDstEnc.getContext(),
|
||||
originDstEnc.getThreadTileSize(),
|
||||
originDstEnc.getWarpTileSize(),
|
||||
originDstEnc.getBlockTileSize(),
|
||||
originDstEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
auto newType = RankedTensorType::get(
|
||||
originDstTensorType.getShape(),
|
||||
originDstTensorType.getElementType(),
|
||||
newEnc
|
||||
);
|
||||
Value newBroadcast = builder.create<triton::BroadcastOp>(
|
||||
broadcast.getLoc(), newType, src
|
||||
);
|
||||
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
|
||||
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
|
||||
originDstEnc.getOrder(), broadcastAxis);
|
||||
auto newType =
|
||||
RankedTensorType::get(originDstTensorType.getShape(),
|
||||
originDstTensorType.getElementType(), newEnc);
|
||||
Value newBroadcast =
|
||||
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
|
||||
// we don't want to change the encoding of the result
|
||||
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
broadcast.getLoc(), originDstType, newBroadcast
|
||||
);
|
||||
broadcast.getLoc(), originDstType, newBroadcast);
|
||||
|
||||
broadcast.replaceAllUsesWith(newDst);
|
||||
mapping.map(broadcast, newDst);
|
||||
|
@@ -5,7 +5,6 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -37,28 +36,30 @@ private:
|
||||
if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of shared layout";
|
||||
} else
|
||||
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
|
||||
return dotOp.emitError()
|
||||
<< name << "'s type should be of RankedTensorType";
|
||||
}
|
||||
|
||||
Attribute cLayout;
|
||||
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
|
||||
llvm::SmallVector<char>{'c', 'd'})) {
|
||||
llvm::SmallVector<char>{'c', 'd'})) {
|
||||
Type type = std::get<0>(it);
|
||||
char name = std::get<1>(it);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of distributed layout";
|
||||
return dotOp.emitError()
|
||||
<< name << " should be of distributed layout";
|
||||
if (name == 'c')
|
||||
cLayout = encoding;
|
||||
else if (encoding != cLayout)
|
||||
return dotOp.emitError() << "d & c should have the same layout";
|
||||
} else
|
||||
return dotOp.emitError() << name
|
||||
<< "'s type should be of RankedTensorType";
|
||||
return dotOp.emitError()
|
||||
<< name << "'s type should be of RankedTensorType";
|
||||
}
|
||||
|
||||
// signalPassFailure();
|
||||
@@ -89,7 +90,7 @@ private:
|
||||
}
|
||||
|
||||
void verifyImpl(Operation *op) {
|
||||
if(verifySingleOp(op).failed())
|
||||
if (verifySingleOp(op).failed())
|
||||
signalPassFailure();
|
||||
|
||||
// verify that all child regions are ok
|
||||
|
408
lib/driver/dispatch.cc
Executable file → Normal file
408
lib/driver/dispatch.cc
Executable file → Normal file
@@ -1,107 +1,152 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
//Helpers for function definition
|
||||
#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname); }\
|
||||
void* dispatch::fname ## _;
|
||||
// Helpers for function definition
|
||||
#define DEFINE0(init, hlib, ret, fname) \
|
||||
ret dispatch::fname() { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE1(init, hlib, ret, fname, t1) \
|
||||
ret dispatch::fname(t1 a) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE2(init, hlib, ret, fname, t1, t2) \
|
||||
ret dispatch::fname(t1 a, t2 b) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\
|
||||
void* dispatch::fname ## _;
|
||||
|
||||
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11, t12, t13) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k, t12 l, t13 m) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k, l, m); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \
|
||||
t18 r, t19 s) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k, l, m, n, o, p, q, r, \
|
||||
s); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
|
||||
bool dispatch::cuinit(){
|
||||
if(cuda_==nullptr){
|
||||
#ifdef _WIN32
|
||||
bool dispatch::cuinit() {
|
||||
if (cuda_ == nullptr) {
|
||||
#ifdef _WIN32
|
||||
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
|
||||
#else
|
||||
#else
|
||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||
if(!cuda_)
|
||||
if (!cuda_)
|
||||
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
|
||||
#endif
|
||||
if(!cuda_)
|
||||
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
|
||||
#endif
|
||||
if (!cuda_)
|
||||
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is "
|
||||
"in your LD_LIBRARY_PATH.");
|
||||
}
|
||||
if(cuda_ == nullptr)
|
||||
if (cuda_ == nullptr)
|
||||
return false;
|
||||
CUresult (*fptr)(unsigned int);
|
||||
cuInit_ = dlsym(cuda_, "cuInit");
|
||||
@@ -112,21 +157,33 @@ bool dispatch::cuinit(){
|
||||
}
|
||||
|
||||
#define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
|
||||
#define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
|
||||
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
|
||||
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
|
||||
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
|
||||
#define CUDA_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
|
||||
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
|
||||
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \
|
||||
DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
|
||||
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
|
||||
DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
|
||||
DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11) \
|
||||
DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11)
|
||||
|
||||
// context management
|
||||
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
|
||||
CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
|
||||
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *)
|
||||
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuInit, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
|
||||
@@ -134,59 +191,71 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
|
||||
CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute,
|
||||
CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
|
||||
|
||||
// link management
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *,
|
||||
size_t, const char *, unsigned int, CUjit_option *, void **);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **,
|
||||
CUlinkState *);
|
||||
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
||||
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
|
||||
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *);
|
||||
// module management
|
||||
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
|
||||
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule,
|
||||
const char *)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
|
||||
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
|
||||
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
|
||||
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
|
||||
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *,
|
||||
unsigned int, CUjit_option *, void **)
|
||||
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule,
|
||||
const char *)
|
||||
// stream management
|
||||
CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
|
||||
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
|
||||
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
|
||||
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
|
||||
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *)
|
||||
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, CUstream, void **, void **)
|
||||
// function management
|
||||
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
|
||||
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
|
||||
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute,
|
||||
CUfunction)
|
||||
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute,
|
||||
int)
|
||||
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
|
||||
// memory management
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
|
||||
CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t )
|
||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
|
||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t,
|
||||
CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t,
|
||||
CUstream)
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t)
|
||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t)
|
||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute,
|
||||
CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t,
|
||||
CUstream)
|
||||
// event management
|
||||
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
|
||||
CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
|
||||
CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
|
||||
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
||||
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
bool dispatch::nvmlinit(){
|
||||
#ifdef _WIN32
|
||||
if(nvml_==nullptr)
|
||||
bool dispatch::nvmlinit() {
|
||||
#ifdef _WIN32
|
||||
if (nvml_ == nullptr)
|
||||
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
|
||||
#else
|
||||
if(nvml_==nullptr)
|
||||
#else
|
||||
if (nvml_ == nullptr)
|
||||
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
|
||||
#endif
|
||||
#endif
|
||||
nvmlReturn_t (*fptr)();
|
||||
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
|
||||
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
|
||||
@@ -197,21 +266,27 @@ bool dispatch::nvmlinit(){
|
||||
|
||||
#define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname)
|
||||
#define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1)
|
||||
#define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
|
||||
#define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
|
||||
#define NVML_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
|
||||
#define NVML_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
|
||||
|
||||
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
|
||||
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *,
|
||||
nvmlDevice_t *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t,
|
||||
nvmlClockType_t, unsigned int *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t,
|
||||
nvmlClockType_t, unsigned int *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t,
|
||||
unsigned int, unsigned int)
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
* ------------------- */
|
||||
bool dispatch::hipinit(){
|
||||
if(hip_==nullptr)
|
||||
bool dispatch::hipinit() {
|
||||
if (hip_ == nullptr)
|
||||
hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
|
||||
if(hip_ == nullptr)
|
||||
if (hip_ == nullptr)
|
||||
return false;
|
||||
hipError_t (*fptr)();
|
||||
hipInit_ = dlsym(hip_, "hipInit");
|
||||
@@ -222,23 +297,34 @@ bool dispatch::hipinit(){
|
||||
}
|
||||
|
||||
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
|
||||
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
|
||||
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
|
||||
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
|
||||
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
|
||||
#define HIP_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(hipinit, hip_, ret, fname, t1, t2)
|
||||
#define HIP_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
|
||||
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \
|
||||
DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
|
||||
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
|
||||
DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
|
||||
DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
|
||||
DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11)
|
||||
|
||||
// context management
|
||||
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
|
||||
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
|
||||
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*)
|
||||
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *)
|
||||
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
|
||||
@@ -246,56 +332,64 @@ HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
|
||||
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t,
|
||||
hipDevice_t)
|
||||
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
|
||||
// module management
|
||||
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*)
|
||||
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *,
|
||||
hipModule_t, const char *)
|
||||
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
|
||||
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
|
||||
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
|
||||
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **)
|
||||
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
|
||||
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *,
|
||||
unsigned int, hipJitOption *, void **)
|
||||
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t,
|
||||
const char *)
|
||||
// stream management
|
||||
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
|
||||
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
|
||||
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **)
|
||||
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, hipStream_t, void **, void **)
|
||||
// function management
|
||||
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*)
|
||||
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *)
|
||||
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
|
||||
// memory management
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
|
||||
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t)
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t )
|
||||
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t)
|
||||
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t,
|
||||
hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *,
|
||||
size_t, hipStream_t)
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t)
|
||||
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t)
|
||||
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute,
|
||||
hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t,
|
||||
hipStream_t)
|
||||
// event management
|
||||
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
|
||||
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
|
||||
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
|
||||
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* COMMON
|
||||
* ------------------- */
|
||||
|
||||
// Release
|
||||
void dispatch::release(){
|
||||
if(cuda_){
|
||||
void dispatch::release() {
|
||||
if (cuda_) {
|
||||
dlclose(cuda_);
|
||||
cuda_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void* dispatch::cuda_;
|
||||
void* dispatch::nvml_;
|
||||
void* dispatch::nvmlInit_v2_;
|
||||
void* dispatch::hip_;
|
||||
void *dispatch::cuda_;
|
||||
void *dispatch::nvml_;
|
||||
void *dispatch::nvmlInit_v2_;
|
||||
void *dispatch::hip_;
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
410
lib/driver/error.cc
Executable file → Normal file
410
lib/driver/error.cc
Executable file → Normal file
@@ -1,166 +1,270 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/driver/error.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
void check(CUresult err)
|
||||
{
|
||||
void check(CUresult err) {
|
||||
using namespace exception::cuda;
|
||||
switch(err)
|
||||
{
|
||||
case CUDA_SUCCESS : break;
|
||||
case CUDA_ERROR_INVALID_VALUE : throw invalid_value();
|
||||
case CUDA_ERROR_OUT_OF_MEMORY : throw out_of_memory();
|
||||
case CUDA_ERROR_NOT_INITIALIZED : throw not_initialized();
|
||||
case CUDA_ERROR_DEINITIALIZED : throw deinitialized();
|
||||
case CUDA_ERROR_PROFILER_DISABLED : throw profiler_disabled();
|
||||
case CUDA_ERROR_PROFILER_NOT_INITIALIZED : throw profiler_not_initialized();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STARTED : throw profiler_already_started();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STOPPED : throw profiler_already_stopped();
|
||||
case CUDA_ERROR_NO_DEVICE : throw no_device();
|
||||
case CUDA_ERROR_INVALID_DEVICE : throw invalid_device();
|
||||
case CUDA_ERROR_INVALID_IMAGE : throw invalid_image();
|
||||
case CUDA_ERROR_INVALID_CONTEXT : throw invalid_context();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT : throw context_already_current();
|
||||
case CUDA_ERROR_MAP_FAILED : throw map_failed();
|
||||
case CUDA_ERROR_UNMAP_FAILED : throw unmap_failed();
|
||||
case CUDA_ERROR_ARRAY_IS_MAPPED : throw array_is_mapped();
|
||||
case CUDA_ERROR_ALREADY_MAPPED : throw already_mapped();
|
||||
case CUDA_ERROR_NO_BINARY_FOR_GPU : throw no_binary_for_gpu();
|
||||
case CUDA_ERROR_ALREADY_ACQUIRED : throw already_acquired();
|
||||
case CUDA_ERROR_NOT_MAPPED : throw not_mapped();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY : throw not_mapped_as_array();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_POINTER : throw not_mapped_as_pointer();
|
||||
case CUDA_ERROR_ECC_UNCORRECTABLE : throw ecc_uncorrectable();
|
||||
case CUDA_ERROR_UNSUPPORTED_LIMIT : throw unsupported_limit();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE : throw context_already_in_use();
|
||||
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED : throw peer_access_unsupported();
|
||||
case CUDA_ERROR_INVALID_PTX : throw invalid_ptx();
|
||||
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT : throw invalid_graphics_context();
|
||||
case CUDA_ERROR_INVALID_SOURCE : throw invalid_source();
|
||||
case CUDA_ERROR_FILE_NOT_FOUND : throw file_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND : throw shared_object_symbol_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED : throw shared_object_init_failed();
|
||||
case CUDA_ERROR_OPERATING_SYSTEM : throw operating_system();
|
||||
case CUDA_ERROR_INVALID_HANDLE : throw invalid_handle();
|
||||
case CUDA_ERROR_NOT_FOUND : throw not_found();
|
||||
case CUDA_ERROR_NOT_READY : throw not_ready();
|
||||
case CUDA_ERROR_ILLEGAL_ADDRESS : throw illegal_address();
|
||||
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES : throw launch_out_of_resources();
|
||||
case CUDA_ERROR_LAUNCH_TIMEOUT : throw launch_timeout();
|
||||
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING : throw launch_incompatible_texturing();
|
||||
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED : throw peer_access_already_enabled();
|
||||
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED : throw peer_access_not_enabled();
|
||||
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE : throw primary_context_active();
|
||||
case CUDA_ERROR_CONTEXT_IS_DESTROYED : throw context_is_destroyed();
|
||||
case CUDA_ERROR_ASSERT : throw assert_error();
|
||||
case CUDA_ERROR_TOO_MANY_PEERS : throw too_many_peers();
|
||||
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED : throw host_memory_already_registered();
|
||||
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED : throw host_memory_not_registered();
|
||||
case CUDA_ERROR_HARDWARE_STACK_ERROR : throw hardware_stack_error();
|
||||
case CUDA_ERROR_ILLEGAL_INSTRUCTION : throw illegal_instruction();
|
||||
case CUDA_ERROR_MISALIGNED_ADDRESS : throw misaligned_address();
|
||||
case CUDA_ERROR_INVALID_ADDRESS_SPACE : throw invalid_address_space();
|
||||
case CUDA_ERROR_INVALID_PC : throw invalid_pc();
|
||||
case CUDA_ERROR_LAUNCH_FAILED : throw launch_failed();
|
||||
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
|
||||
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
|
||||
case CUDA_ERROR_UNKNOWN : throw unknown();
|
||||
default : throw unknown();
|
||||
switch (err) {
|
||||
case CUDA_SUCCESS:
|
||||
break;
|
||||
case CUDA_ERROR_INVALID_VALUE:
|
||||
throw invalid_value();
|
||||
case CUDA_ERROR_OUT_OF_MEMORY:
|
||||
throw out_of_memory();
|
||||
case CUDA_ERROR_NOT_INITIALIZED:
|
||||
throw not_initialized();
|
||||
case CUDA_ERROR_DEINITIALIZED:
|
||||
throw deinitialized();
|
||||
case CUDA_ERROR_PROFILER_DISABLED:
|
||||
throw profiler_disabled();
|
||||
case CUDA_ERROR_PROFILER_NOT_INITIALIZED:
|
||||
throw profiler_not_initialized();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STARTED:
|
||||
throw profiler_already_started();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STOPPED:
|
||||
throw profiler_already_stopped();
|
||||
case CUDA_ERROR_NO_DEVICE:
|
||||
throw no_device();
|
||||
case CUDA_ERROR_INVALID_DEVICE:
|
||||
throw invalid_device();
|
||||
case CUDA_ERROR_INVALID_IMAGE:
|
||||
throw invalid_image();
|
||||
case CUDA_ERROR_INVALID_CONTEXT:
|
||||
throw invalid_context();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT:
|
||||
throw context_already_current();
|
||||
case CUDA_ERROR_MAP_FAILED:
|
||||
throw map_failed();
|
||||
case CUDA_ERROR_UNMAP_FAILED:
|
||||
throw unmap_failed();
|
||||
case CUDA_ERROR_ARRAY_IS_MAPPED:
|
||||
throw array_is_mapped();
|
||||
case CUDA_ERROR_ALREADY_MAPPED:
|
||||
throw already_mapped();
|
||||
case CUDA_ERROR_NO_BINARY_FOR_GPU:
|
||||
throw no_binary_for_gpu();
|
||||
case CUDA_ERROR_ALREADY_ACQUIRED:
|
||||
throw already_acquired();
|
||||
case CUDA_ERROR_NOT_MAPPED:
|
||||
throw not_mapped();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY:
|
||||
throw not_mapped_as_array();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_POINTER:
|
||||
throw not_mapped_as_pointer();
|
||||
case CUDA_ERROR_ECC_UNCORRECTABLE:
|
||||
throw ecc_uncorrectable();
|
||||
case CUDA_ERROR_UNSUPPORTED_LIMIT:
|
||||
throw unsupported_limit();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE:
|
||||
throw context_already_in_use();
|
||||
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED:
|
||||
throw peer_access_unsupported();
|
||||
case CUDA_ERROR_INVALID_PTX:
|
||||
throw invalid_ptx();
|
||||
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT:
|
||||
throw invalid_graphics_context();
|
||||
case CUDA_ERROR_INVALID_SOURCE:
|
||||
throw invalid_source();
|
||||
case CUDA_ERROR_FILE_NOT_FOUND:
|
||||
throw file_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND:
|
||||
throw shared_object_symbol_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
|
||||
throw shared_object_init_failed();
|
||||
case CUDA_ERROR_OPERATING_SYSTEM:
|
||||
throw operating_system();
|
||||
case CUDA_ERROR_INVALID_HANDLE:
|
||||
throw invalid_handle();
|
||||
case CUDA_ERROR_NOT_FOUND:
|
||||
throw not_found();
|
||||
case CUDA_ERROR_NOT_READY:
|
||||
throw not_ready();
|
||||
case CUDA_ERROR_ILLEGAL_ADDRESS:
|
||||
throw illegal_address();
|
||||
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
|
||||
throw launch_out_of_resources();
|
||||
case CUDA_ERROR_LAUNCH_TIMEOUT:
|
||||
throw launch_timeout();
|
||||
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING:
|
||||
throw launch_incompatible_texturing();
|
||||
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED:
|
||||
throw peer_access_already_enabled();
|
||||
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED:
|
||||
throw peer_access_not_enabled();
|
||||
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE:
|
||||
throw primary_context_active();
|
||||
case CUDA_ERROR_CONTEXT_IS_DESTROYED:
|
||||
throw context_is_destroyed();
|
||||
case CUDA_ERROR_ASSERT:
|
||||
throw assert_error();
|
||||
case CUDA_ERROR_TOO_MANY_PEERS:
|
||||
throw too_many_peers();
|
||||
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED:
|
||||
throw host_memory_already_registered();
|
||||
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED:
|
||||
throw host_memory_not_registered();
|
||||
case CUDA_ERROR_HARDWARE_STACK_ERROR:
|
||||
throw hardware_stack_error();
|
||||
case CUDA_ERROR_ILLEGAL_INSTRUCTION:
|
||||
throw illegal_instruction();
|
||||
case CUDA_ERROR_MISALIGNED_ADDRESS:
|
||||
throw misaligned_address();
|
||||
case CUDA_ERROR_INVALID_ADDRESS_SPACE:
|
||||
throw invalid_address_space();
|
||||
case CUDA_ERROR_INVALID_PC:
|
||||
throw invalid_pc();
|
||||
case CUDA_ERROR_LAUNCH_FAILED:
|
||||
throw launch_failed();
|
||||
case CUDA_ERROR_NOT_PERMITTED:
|
||||
throw not_permitted();
|
||||
case CUDA_ERROR_NOT_SUPPORTED:
|
||||
throw not_supported();
|
||||
case CUDA_ERROR_UNKNOWN:
|
||||
throw unknown();
|
||||
default:
|
||||
throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
void check(hipError_t error) {
|
||||
using namespace exception::hip;
|
||||
switch(error)
|
||||
{
|
||||
case hipSuccess : break;
|
||||
case hipErrorInvalidValue : throw invalid_value();
|
||||
case hipErrorMemoryAllocation : throw out_of_memory();
|
||||
case hipErrorNotInitialized : throw not_initialized();
|
||||
case hipErrorDeinitialized : throw deinitialized();
|
||||
case hipErrorProfilerDisabled : throw profiler_disabled();
|
||||
case hipErrorProfilerNotInitialized : throw profiler_not_initialized();
|
||||
case hipErrorProfilerAlreadyStarted : throw profiler_already_started();
|
||||
case hipErrorProfilerAlreadyStopped : throw profiler_already_stopped();
|
||||
case hipErrorNoDevice : throw no_device();
|
||||
case hipErrorInvalidSymbol : throw invalid_symbol();
|
||||
case hipErrorInvalidDevice : throw invalid_device();
|
||||
case hipErrorInvalidImage : throw invalid_image();
|
||||
case hipErrorInvalidContext : throw invalid_context();
|
||||
case hipErrorContextAlreadyCurrent : throw context_already_current();
|
||||
case hipErrorMapFailed : throw map_failed();
|
||||
case hipErrorUnmapFailed : throw unmap_failed();
|
||||
case hipErrorArrayIsMapped : throw array_is_mapped();
|
||||
case hipErrorAlreadyMapped : throw already_mapped();
|
||||
case hipErrorNoBinaryForGpu : throw no_binary_for_gpu();
|
||||
case hipErrorAlreadyAcquired : throw already_acquired();
|
||||
case hipErrorNotMapped : throw not_mapped();
|
||||
case hipErrorNotMappedAsArray : throw not_mapped_as_array();
|
||||
case hipErrorNotMappedAsPointer : throw not_mapped_as_pointer();
|
||||
case hipErrorECCNotCorrectable : throw ecc_uncorrectable();
|
||||
case hipErrorUnsupportedLimit : throw unsupported_limit();
|
||||
case hipErrorContextAlreadyInUse : throw context_already_in_use();
|
||||
case hipErrorPeerAccessUnsupported : throw peer_access_unsupported();
|
||||
case hipErrorInvalidKernelFile : throw invalid_ptx();
|
||||
case hipErrorInvalidGraphicsContext : throw invalid_graphics_context();
|
||||
case hipErrorInvalidSource : throw invalid_source();
|
||||
case hipErrorFileNotFound : throw file_not_found();
|
||||
case hipErrorSharedObjectSymbolNotFound : throw shared_object_symbol_not_found();
|
||||
case hipErrorSharedObjectInitFailed : throw shared_object_init_failed();
|
||||
case hipErrorOperatingSystem : throw operating_system();
|
||||
case hipErrorInvalidResourceHandle : throw invalid_handle();
|
||||
case hipErrorNotFound : throw not_found();
|
||||
case hipErrorNotReady : throw not_ready();
|
||||
case hipErrorIllegalAddress : throw illegal_address();
|
||||
case hipErrorLaunchOutOfResources : throw launch_out_of_resources();
|
||||
case hipErrorLaunchTimeOut : throw launch_timeout();
|
||||
// case hipErrorLaunchIncompatibleTexturing : throw launch_incompatible_texturing();
|
||||
case hipErrorPeerAccessAlreadyEnabled : throw peer_access_already_enabled();
|
||||
case hipErrorPeerAccessNotEnabled : throw peer_access_not_enabled();
|
||||
// case hipErrorPrimaryContextActive : throw primary_context_active();
|
||||
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
|
||||
case hipErrorAssert : throw assert_error();
|
||||
// case hipErrorTooManyPeers : throw too_many_peers();
|
||||
case hipErrorHostMemoryAlreadyRegistered : throw host_memory_already_registered();
|
||||
case hipErrorHostMemoryNotRegistered : throw host_memory_not_registered();
|
||||
// case hipErrorHardwareStackError : throw hardware_stack_error();
|
||||
// case hipErrorIllegalInstruction : throw illegal_instruction();
|
||||
// case hipErrorMisalignedAddress : throw misaligned_address();
|
||||
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
|
||||
// case hipErrorInvalidPc : throw invalid_pc();
|
||||
case hipErrorLaunchFailure : throw launch_failed();
|
||||
// case hipErrorNotPermitted : throw not_permitted();
|
||||
case hipErrorNotSupported : throw not_supported();
|
||||
case hipErrorUnknown : throw unknown();
|
||||
default : throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
switch (error) {
|
||||
case hipSuccess:
|
||||
break;
|
||||
case hipErrorInvalidValue:
|
||||
throw invalid_value();
|
||||
case hipErrorMemoryAllocation:
|
||||
throw out_of_memory();
|
||||
case hipErrorNotInitialized:
|
||||
throw not_initialized();
|
||||
case hipErrorDeinitialized:
|
||||
throw deinitialized();
|
||||
case hipErrorProfilerDisabled:
|
||||
throw profiler_disabled();
|
||||
case hipErrorProfilerNotInitialized:
|
||||
throw profiler_not_initialized();
|
||||
case hipErrorProfilerAlreadyStarted:
|
||||
throw profiler_already_started();
|
||||
case hipErrorProfilerAlreadyStopped:
|
||||
throw profiler_already_stopped();
|
||||
case hipErrorNoDevice:
|
||||
throw no_device();
|
||||
case hipErrorInvalidSymbol:
|
||||
throw invalid_symbol();
|
||||
case hipErrorInvalidDevice:
|
||||
throw invalid_device();
|
||||
case hipErrorInvalidImage:
|
||||
throw invalid_image();
|
||||
case hipErrorInvalidContext:
|
||||
throw invalid_context();
|
||||
case hipErrorContextAlreadyCurrent:
|
||||
throw context_already_current();
|
||||
case hipErrorMapFailed:
|
||||
throw map_failed();
|
||||
case hipErrorUnmapFailed:
|
||||
throw unmap_failed();
|
||||
case hipErrorArrayIsMapped:
|
||||
throw array_is_mapped();
|
||||
case hipErrorAlreadyMapped:
|
||||
throw already_mapped();
|
||||
case hipErrorNoBinaryForGpu:
|
||||
throw no_binary_for_gpu();
|
||||
case hipErrorAlreadyAcquired:
|
||||
throw already_acquired();
|
||||
case hipErrorNotMapped:
|
||||
throw not_mapped();
|
||||
case hipErrorNotMappedAsArray:
|
||||
throw not_mapped_as_array();
|
||||
case hipErrorNotMappedAsPointer:
|
||||
throw not_mapped_as_pointer();
|
||||
case hipErrorECCNotCorrectable:
|
||||
throw ecc_uncorrectable();
|
||||
case hipErrorUnsupportedLimit:
|
||||
throw unsupported_limit();
|
||||
case hipErrorContextAlreadyInUse:
|
||||
throw context_already_in_use();
|
||||
case hipErrorPeerAccessUnsupported:
|
||||
throw peer_access_unsupported();
|
||||
case hipErrorInvalidKernelFile:
|
||||
throw invalid_ptx();
|
||||
case hipErrorInvalidGraphicsContext:
|
||||
throw invalid_graphics_context();
|
||||
case hipErrorInvalidSource:
|
||||
throw invalid_source();
|
||||
case hipErrorFileNotFound:
|
||||
throw file_not_found();
|
||||
case hipErrorSharedObjectSymbolNotFound:
|
||||
throw shared_object_symbol_not_found();
|
||||
case hipErrorSharedObjectInitFailed:
|
||||
throw shared_object_init_failed();
|
||||
case hipErrorOperatingSystem:
|
||||
throw operating_system();
|
||||
case hipErrorInvalidResourceHandle:
|
||||
throw invalid_handle();
|
||||
case hipErrorNotFound:
|
||||
throw not_found();
|
||||
case hipErrorNotReady:
|
||||
throw not_ready();
|
||||
case hipErrorIllegalAddress:
|
||||
throw illegal_address();
|
||||
case hipErrorLaunchOutOfResources:
|
||||
throw launch_out_of_resources();
|
||||
case hipErrorLaunchTimeOut:
|
||||
throw launch_timeout();
|
||||
// case hipErrorLaunchIncompatibleTexturing : throw
|
||||
// launch_incompatible_texturing();
|
||||
case hipErrorPeerAccessAlreadyEnabled:
|
||||
throw peer_access_already_enabled();
|
||||
case hipErrorPeerAccessNotEnabled:
|
||||
throw peer_access_not_enabled();
|
||||
// case hipErrorPrimaryContextActive : throw primary_context_active();
|
||||
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
|
||||
case hipErrorAssert:
|
||||
throw assert_error();
|
||||
// case hipErrorTooManyPeers : throw too_many_peers();
|
||||
case hipErrorHostMemoryAlreadyRegistered:
|
||||
throw host_memory_already_registered();
|
||||
case hipErrorHostMemoryNotRegistered:
|
||||
throw host_memory_not_registered();
|
||||
// case hipErrorHardwareStackError : throw hardware_stack_error();
|
||||
// case hipErrorIllegalInstruction : throw illegal_instruction();
|
||||
// case hipErrorMisalignedAddress : throw misaligned_address();
|
||||
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
|
||||
// case hipErrorInvalidPc : throw invalid_pc();
|
||||
case hipErrorLaunchFailure:
|
||||
throw launch_failed();
|
||||
// case hipErrorNotPermitted : throw not_permitted();
|
||||
case hipErrorNotSupported:
|
||||
throw not_supported();
|
||||
case hipErrorUnknown:
|
||||
throw unknown();
|
||||
default:
|
||||
throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
@@ -1,73 +1,73 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
#include <fstream>
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/tools/sha1.hpp"
|
||||
#include "triton/tools/sys/exec.hpp"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/tools/sys/mkdir.hpp"
|
||||
#include "triton/tools/sys/exec.hpp"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/Support/CodeGen.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "llvm/Transforms/Scalar.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
|
||||
// begin AMD stuff
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/FormattedStream.h"
|
||||
#include "llvm/Support/Program.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
// end AMD stuff
|
||||
|
||||
extern "C"{
|
||||
int set_curterm(char* nterm){ return 0; }
|
||||
int del_curterm(char* nterm){ return 0; }
|
||||
int tigetnum(char *capname) { return 0; }
|
||||
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
||||
extern "C" {
|
||||
int set_curterm(char *nterm) { return 0; }
|
||||
int del_curterm(char *nterm) { return 0; }
|
||||
int tigetnum(char *capname) { return 0; }
|
||||
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace driver{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
void init_llvm() {
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
@@ -80,82 +80,93 @@ void init_llvm() {
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
|
||||
static bool find_and_replace(std::string &str, const std::string &begin,
|
||||
const std::string &end,
|
||||
const std::string &target) {
|
||||
size_t start_replace = str.find(begin);
|
||||
size_t end_replace = str.find(end, start_replace);
|
||||
if(start_replace == std::string::npos)
|
||||
if (start_replace == std::string::npos)
|
||||
return false;
|
||||
str.replace(start_replace, end_replace + 1 - start_replace, target);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string path_to_ptxas(int& version) {
|
||||
std::string path_to_ptxas(int &version) {
|
||||
std::vector<std::string> rets;
|
||||
std::string ret;
|
||||
// search pathes for ptxas
|
||||
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
||||
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
||||
if(!triton_ptxas.empty())
|
||||
if (!triton_ptxas.empty())
|
||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||
// see what path for ptxas are valid
|
||||
std::vector<std::string> working_ptxas;
|
||||
for(std::string prefix: ptxas_prefixes){
|
||||
for (std::string prefix : ptxas_prefixes) {
|
||||
std::string ptxas = prefix + "ptxas";
|
||||
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
||||
if(works) {
|
||||
if (works) {
|
||||
working_ptxas.push_back(ptxas);
|
||||
rets.push_back(ret);
|
||||
}
|
||||
}
|
||||
// error if no working ptxas was found
|
||||
if(working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
||||
if (working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, "
|
||||
"/usr/local/cuda/bin/ or PATH"
|
||||
" but a working version could not be found.");
|
||||
std::string ptxas = working_ptxas.front();
|
||||
// parse version
|
||||
std::regex version_regex("release (\\d+)\\.(\\d+)");
|
||||
std::smatch match;
|
||||
bool found = false;
|
||||
// currently choosing the first ptxas. Other logics can be implemented in future
|
||||
for(std::string ret : rets) {
|
||||
if(std::regex_search(ret, match, version_regex)){
|
||||
// currently choosing the first ptxas. Other logics can be implemented in
|
||||
// future
|
||||
for (std::string ret : rets) {
|
||||
if (std::regex_search(ret, match, version_regex)) {
|
||||
int major = std::stoi(match[1]);
|
||||
int minor = std::stoi(match[2]);
|
||||
version = major*1000 + minor*10;
|
||||
version = major * 1000 + minor * 10;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if ( not found) {
|
||||
if (not found) {
|
||||
throw std::runtime_error("Error in parsing version");
|
||||
}
|
||||
return ptxas;
|
||||
}
|
||||
|
||||
|
||||
int vptx(int version){
|
||||
if(version >= 11040) return 74;
|
||||
if(version >= 11030) return 73;
|
||||
if(version >= 11020) return 72;
|
||||
if(version >= 11010) return 71;
|
||||
if(version >= 11000) return 70;
|
||||
if(version >= 10020) return 65;
|
||||
if(version >= 10010) return 64;
|
||||
if(version >= 10000) return 63;
|
||||
int vptx(int version) {
|
||||
if (version >= 11040)
|
||||
return 74;
|
||||
if (version >= 11030)
|
||||
return 73;
|
||||
if (version >= 11020)
|
||||
return 72;
|
||||
if (version >= 11010)
|
||||
return 71;
|
||||
if (version >= 11000)
|
||||
return 70;
|
||||
if (version >= 10020)
|
||||
return 65;
|
||||
if (version >= 10010)
|
||||
return 64;
|
||||
if (version >= 10000)
|
||||
return 63;
|
||||
throw std::runtime_error("Triton requires CUDA 10+");
|
||||
}
|
||||
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
std::string llir_to_ptx(llvm::Module *module, int cc, int version) {
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
int max_nvvm_ptx = 74;
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||
auto *short_ptr =
|
||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||
assert(short_ptr);
|
||||
short_ptr->setValue(true);
|
||||
// compute capability
|
||||
@@ -170,7 +181,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
||||
std::string layout = "";
|
||||
std::string features = "";
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
||||
// max_nvvm_ptx));
|
||||
init_llvm();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
@@ -181,16 +193,18 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
if (layout.empty())
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
else
|
||||
module->setDataLayout(layout);
|
||||
@@ -200,19 +214,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
pass.run(*module);
|
||||
|
||||
// post-process
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".version", "\n",
|
||||
".version " + std::to_string(ptx_major) + "." +
|
||||
std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
|
||||
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
|
||||
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
|
||||
;
|
||||
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
|
||||
;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas,
|
||||
int cc) {
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[L_tmpnam];
|
||||
char _flog[L_tmpnam];
|
||||
@@ -221,15 +241,16 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
std::string fsrc = _fsrc;
|
||||
std::string flog = _flog;
|
||||
std::string fbin = fsrc + ".o";
|
||||
const char* _fbin = fbin.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(fsrc);
|
||||
ofs << ptx << std::endl;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc +
|
||||
" -o " + fsrc + ".o 2> " + flog;
|
||||
err = system(cmd.c_str());
|
||||
if(err != 0){
|
||||
if (err != 0) {
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
unlink(_fsrc);
|
||||
@@ -237,7 +258,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
|
||||
}
|
||||
CUmodule ret;
|
||||
std::ifstream _cubin(_fbin, std::ios::binary );
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
unlink(_fsrc);
|
||||
@@ -251,11 +272,11 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
// HIP //
|
||||
/* ------------------------ */
|
||||
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) {
|
||||
init_llvm();
|
||||
|
||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
@@ -270,17 +291,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None,
|
||||
llvm::CodeGenOpt::Aggressive);
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
if (layout.empty())
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
else
|
||||
module->setDataLayout(layout);
|
||||
@@ -295,33 +317,37 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
std::error_code ec;
|
||||
|
||||
// Save GCN ISA binary.
|
||||
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
|
||||
std::string isabin_path =
|
||||
std::string("/tmp/") + module_name + std::string(".o");
|
||||
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
|
||||
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
|
||||
if (ec)
|
||||
{
|
||||
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
|
||||
if (ec) {
|
||||
std::cout << isabin_path << " was not created. error code: " << ec
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
|
||||
llvm::CGFT_ObjectFile);
|
||||
pass.run(*module);
|
||||
// Save GCN ISA.
|
||||
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
|
||||
std::string amdgcn_path =
|
||||
std::string("/tmp/") + module_name + std::string(".gcn");
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
std::ofstream amdgcn(amdgcn_path);
|
||||
amdgcn << result;
|
||||
amdgcn.close();
|
||||
|
||||
// generate HASCO file
|
||||
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
|
||||
std::string hsaco_path =
|
||||
std::string("/tmp/") + module_name + std::string(".hsaco");
|
||||
std::string error_message;
|
||||
int lld_result =
|
||||
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
|
||||
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
|
||||
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
|
||||
"-shared", "-o", hsaco_path, isabin_path},
|
||||
llvm::None, {}, 0, 0, &error_message);
|
||||
if (lld_result)
|
||||
{
|
||||
if (lld_result) {
|
||||
std::cout << "ld.lld execute fail: " << std::endl;
|
||||
std::cout << error_message << std::endl;
|
||||
std::cout << lld_result << std::endl;
|
||||
@@ -330,33 +356,29 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
return hsaco_path;
|
||||
}
|
||||
|
||||
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string& path) {
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string &path) {
|
||||
// Read HSACO.
|
||||
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
|
||||
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
|
||||
|
||||
std::vector<unsigned char> hsaco(hsaco_file_size);
|
||||
hsaco_file.seekg(0, std::ios::beg);
|
||||
hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
|
||||
hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
|
||||
hsaco_file.close();
|
||||
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
|
||||
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
|
||||
hipJitOptionLogVerbose};
|
||||
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
|
||||
hipJitOptionErrorLogBuffer,
|
||||
hipJitOptionInfoLogBufferSizeBytes,
|
||||
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
|
||||
const unsigned int errbufsize = 8192;
|
||||
const unsigned int logbufsize = 8192;
|
||||
char _err[errbufsize];
|
||||
char _log[logbufsize];
|
||||
void* optval[] = {(void*)(uintptr_t)errbufsize,
|
||||
(void*)_err, (void*)(uintptr_t)logbufsize,
|
||||
(void*)_log, (void*)1};
|
||||
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
|
||||
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
|
||||
hipModule_t ret;
|
||||
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
@@ -18,60 +18,83 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
/// @{
|
||||
|
||||
/// Annotation for methods
|
||||
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
|
||||
struct is_method {
|
||||
handle class_;
|
||||
is_method(const handle &c) : class_(c) {}
|
||||
};
|
||||
|
||||
/// Annotation for operators
|
||||
struct is_operator { };
|
||||
struct is_operator {};
|
||||
|
||||
/// Annotation for parent scope
|
||||
struct scope { handle value; scope(const handle &s) : value(s) { } };
|
||||
struct scope {
|
||||
handle value;
|
||||
scope(const handle &s) : value(s) {}
|
||||
};
|
||||
|
||||
/// Annotation for documentation
|
||||
struct doc { const char *value; doc(const char *value) : value(value) { } };
|
||||
struct doc {
|
||||
const char *value;
|
||||
doc(const char *value) : value(value) {}
|
||||
};
|
||||
|
||||
/// Annotation for function names
|
||||
struct name { const char *value; name(const char *value) : value(value) { } };
|
||||
struct name {
|
||||
const char *value;
|
||||
name(const char *value) : value(value) {}
|
||||
};
|
||||
|
||||
/// Annotation indicating that a function is an overload associated with a given "sibling"
|
||||
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
|
||||
/// Annotation indicating that a function is an overload associated with a given
|
||||
/// "sibling"
|
||||
struct sibling {
|
||||
handle value;
|
||||
sibling(const handle &value) : value(value.ptr()) {}
|
||||
};
|
||||
|
||||
/// Annotation indicating that a class derives from another given type
|
||||
template <typename T> struct base {
|
||||
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
|
||||
base() { }
|
||||
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as "
|
||||
"a template argument to class_")
|
||||
base() {}
|
||||
};
|
||||
|
||||
/// Keep patient alive while nurse lives
|
||||
template <size_t Nurse, size_t Patient> struct keep_alive { };
|
||||
template <size_t Nurse, size_t Patient> struct keep_alive {};
|
||||
|
||||
/// Annotation indicating that a class is involved in a multiple inheritance relationship
|
||||
struct multiple_inheritance { };
|
||||
/// Annotation indicating that a class is involved in a multiple inheritance
|
||||
/// relationship
|
||||
struct multiple_inheritance {};
|
||||
|
||||
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
|
||||
struct dynamic_attr { };
|
||||
struct dynamic_attr {};
|
||||
|
||||
/// Annotation which enables the buffer protocol for a type
|
||||
struct buffer_protocol { };
|
||||
struct buffer_protocol {};
|
||||
|
||||
/// Annotation which requests that a special metaclass is created for a type
|
||||
struct metaclass {
|
||||
handle value;
|
||||
handle value;
|
||||
|
||||
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
|
||||
metaclass() {}
|
||||
PYBIND11_DEPRECATED(
|
||||
"py::metaclass() is no longer required. It's turned on by default now.")
|
||||
metaclass() {}
|
||||
|
||||
/// Override pybind11's default metaclass
|
||||
explicit metaclass(handle value) : value(value) { }
|
||||
/// Override pybind11's default metaclass
|
||||
explicit metaclass(handle value) : value(value) {}
|
||||
};
|
||||
|
||||
/// Annotation that marks a class as local to the module:
|
||||
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
|
||||
struct module_local {
|
||||
const bool value;
|
||||
constexpr module_local(bool v = true) : value(v) {}
|
||||
};
|
||||
|
||||
/// Annotation to mark enums as an arithmetic type
|
||||
struct arithmetic { };
|
||||
struct arithmetic {};
|
||||
|
||||
/** \rst
|
||||
A call policy which places one or more guard variables (``Ts...``) around the function call.
|
||||
A call policy which places one or more guard variables (``Ts...``) around
|
||||
the function call.
|
||||
|
||||
For example, this definition:
|
||||
|
||||
@@ -92,20 +115,19 @@ template <typename... Ts> struct call_guard;
|
||||
|
||||
template <> struct call_guard<> { using type = detail::void_type; };
|
||||
|
||||
template <typename T>
|
||||
struct call_guard<T> {
|
||||
static_assert(std::is_default_constructible<T>::value,
|
||||
"The guard type must be default constructible");
|
||||
template <typename T> struct call_guard<T> {
|
||||
static_assert(std::is_default_constructible<T>::value,
|
||||
"The guard type must be default constructible");
|
||||
|
||||
using type = T;
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
struct call_guard<T, Ts...> {
|
||||
struct type {
|
||||
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
|
||||
typename call_guard<Ts...>::type next{};
|
||||
};
|
||||
template <typename T, typename... Ts> struct call_guard<T, Ts...> {
|
||||
struct type {
|
||||
T guard{}; // Compose multiple guard types with left-to-right
|
||||
// default-constructor order
|
||||
typename call_guard<Ts...>::type next{};
|
||||
};
|
||||
};
|
||||
|
||||
/// @} annotations
|
||||
@@ -115,181 +137,190 @@ NAMESPACE_BEGIN(detail)
|
||||
enum op_id : int;
|
||||
enum op_type : int;
|
||||
struct undefined_t;
|
||||
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
|
||||
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
|
||||
template <op_id id, op_type ot, typename L = undefined_t,
|
||||
typename R = undefined_t>
|
||||
struct op_;
|
||||
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call,
|
||||
handle ret);
|
||||
|
||||
/// Internal data structure which holds metadata about a keyword argument
|
||||
struct argument_record {
|
||||
const char *name; ///< Argument name
|
||||
const char *descr; ///< Human-readable version of the argument value
|
||||
handle value; ///< Associated Python object
|
||||
bool convert : 1; ///< True if the argument is allowed to convert when loading
|
||||
bool none : 1; ///< True if None is allowed when loading
|
||||
const char *name; ///< Argument name
|
||||
const char *descr; ///< Human-readable version of the argument value
|
||||
handle value; ///< Associated Python object
|
||||
bool convert : 1; ///< True if the argument is allowed to convert when loading
|
||||
bool none : 1; ///< True if None is allowed when loading
|
||||
|
||||
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
|
||||
: name(name), descr(descr), value(value), convert(convert), none(none) { }
|
||||
argument_record(const char *name, const char *descr, handle value,
|
||||
bool convert, bool none)
|
||||
: name(name), descr(descr), value(value), convert(convert), none(none) {}
|
||||
};
|
||||
|
||||
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
|
||||
/// Internal data structure which holds metadata about a bound function
|
||||
/// (signature, overloads, etc.)
|
||||
struct function_record {
|
||||
function_record()
|
||||
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
|
||||
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
|
||||
function_record()
|
||||
: is_constructor(false), is_new_style_constructor(false),
|
||||
is_stateless(false), is_operator(false), has_args(false),
|
||||
has_kwargs(false), is_method(false) {}
|
||||
|
||||
/// Function name
|
||||
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
|
||||
/// Function name
|
||||
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
|
||||
|
||||
// User-specified documentation string
|
||||
char *doc = nullptr;
|
||||
// User-specified documentation string
|
||||
char *doc = nullptr;
|
||||
|
||||
/// Human-readable version of the function signature
|
||||
char *signature = nullptr;
|
||||
/// Human-readable version of the function signature
|
||||
char *signature = nullptr;
|
||||
|
||||
/// List of registered keyword arguments
|
||||
std::vector<argument_record> args;
|
||||
/// List of registered keyword arguments
|
||||
std::vector<argument_record> args;
|
||||
|
||||
/// Pointer to lambda function which converts arguments and performs the actual call
|
||||
handle (*impl) (function_call &) = nullptr;
|
||||
/// Pointer to lambda function which converts arguments and performs the
|
||||
/// actual call
|
||||
handle (*impl)(function_call &) = nullptr;
|
||||
|
||||
/// Storage for the wrapped function pointer and captured data, if any
|
||||
void *data[3] = { };
|
||||
/// Storage for the wrapped function pointer and captured data, if any
|
||||
void *data[3] = {};
|
||||
|
||||
/// Pointer to custom destructor for 'data' (if needed)
|
||||
void (*free_data) (function_record *ptr) = nullptr;
|
||||
/// Pointer to custom destructor for 'data' (if needed)
|
||||
void (*free_data)(function_record *ptr) = nullptr;
|
||||
|
||||
/// Return value policy associated with this function
|
||||
return_value_policy policy = return_value_policy::automatic;
|
||||
/// Return value policy associated with this function
|
||||
return_value_policy policy = return_value_policy::automatic;
|
||||
|
||||
/// True if name == '__init__'
|
||||
bool is_constructor : 1;
|
||||
/// True if name == '__init__'
|
||||
bool is_constructor : 1;
|
||||
|
||||
/// True if this is a new-style `__init__` defined in `detail/init.h`
|
||||
bool is_new_style_constructor : 1;
|
||||
/// True if this is a new-style `__init__` defined in `detail/init.h`
|
||||
bool is_new_style_constructor : 1;
|
||||
|
||||
/// True if this is a stateless function pointer
|
||||
bool is_stateless : 1;
|
||||
/// True if this is a stateless function pointer
|
||||
bool is_stateless : 1;
|
||||
|
||||
/// True if this is an operator (__add__), etc.
|
||||
bool is_operator : 1;
|
||||
/// True if this is an operator (__add__), etc.
|
||||
bool is_operator : 1;
|
||||
|
||||
/// True if the function has a '*args' argument
|
||||
bool has_args : 1;
|
||||
/// True if the function has a '*args' argument
|
||||
bool has_args : 1;
|
||||
|
||||
/// True if the function has a '**kwargs' argument
|
||||
bool has_kwargs : 1;
|
||||
/// True if the function has a '**kwargs' argument
|
||||
bool has_kwargs : 1;
|
||||
|
||||
/// True if this is a method
|
||||
bool is_method : 1;
|
||||
/// True if this is a method
|
||||
bool is_method : 1;
|
||||
|
||||
/// Number of arguments (including py::args and/or py::kwargs, if present)
|
||||
std::uint16_t nargs;
|
||||
/// Number of arguments (including py::args and/or py::kwargs, if present)
|
||||
std::uint16_t nargs;
|
||||
|
||||
/// Python method object
|
||||
PyMethodDef *def = nullptr;
|
||||
/// Python method object
|
||||
PyMethodDef *def = nullptr;
|
||||
|
||||
/// Python handle to the parent scope (a class or a module)
|
||||
handle scope;
|
||||
/// Python handle to the parent scope (a class or a module)
|
||||
handle scope;
|
||||
|
||||
/// Python handle to the sibling function representing an overload chain
|
||||
handle sibling;
|
||||
/// Python handle to the sibling function representing an overload chain
|
||||
handle sibling;
|
||||
|
||||
/// Pointer to next overload
|
||||
function_record *next = nullptr;
|
||||
/// Pointer to next overload
|
||||
function_record *next = nullptr;
|
||||
};
|
||||
|
||||
/// Special data structure which (temporarily) holds metadata about a bound class
|
||||
/// Special data structure which (temporarily) holds metadata about a bound
|
||||
/// class
|
||||
struct type_record {
|
||||
PYBIND11_NOINLINE type_record()
|
||||
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
|
||||
default_holder(true), module_local(false) { }
|
||||
PYBIND11_NOINLINE type_record()
|
||||
: multiple_inheritance(false), dynamic_attr(false),
|
||||
buffer_protocol(false), default_holder(true), module_local(false) {}
|
||||
|
||||
/// Handle to the parent scope
|
||||
handle scope;
|
||||
/// Handle to the parent scope
|
||||
handle scope;
|
||||
|
||||
/// Name of the class
|
||||
const char *name = nullptr;
|
||||
/// Name of the class
|
||||
const char *name = nullptr;
|
||||
|
||||
// Pointer to RTTI type_info data structure
|
||||
const std::type_info *type = nullptr;
|
||||
// Pointer to RTTI type_info data structure
|
||||
const std::type_info *type = nullptr;
|
||||
|
||||
/// How large is the underlying C++ type?
|
||||
size_t type_size = 0;
|
||||
/// How large is the underlying C++ type?
|
||||
size_t type_size = 0;
|
||||
|
||||
/// What is the alignment of the underlying C++ type?
|
||||
size_t type_align = 0;
|
||||
/// What is the alignment of the underlying C++ type?
|
||||
size_t type_align = 0;
|
||||
|
||||
/// How large is the type's holder?
|
||||
size_t holder_size = 0;
|
||||
/// How large is the type's holder?
|
||||
size_t holder_size = 0;
|
||||
|
||||
/// The global operator new can be overridden with a class-specific variant
|
||||
void *(*operator_new)(size_t) = nullptr;
|
||||
/// The global operator new can be overridden with a class-specific variant
|
||||
void *(*operator_new)(size_t) = nullptr;
|
||||
|
||||
/// Function pointer to class_<..>::init_instance
|
||||
void (*init_instance)(instance *, const void *) = nullptr;
|
||||
/// Function pointer to class_<..>::init_instance
|
||||
void (*init_instance)(instance *, const void *) = nullptr;
|
||||
|
||||
/// Function pointer to class_<..>::dealloc
|
||||
void (*dealloc)(detail::value_and_holder &) = nullptr;
|
||||
/// Function pointer to class_<..>::dealloc
|
||||
void (*dealloc)(detail::value_and_holder &) = nullptr;
|
||||
|
||||
/// List of base classes of the newly created type
|
||||
list bases;
|
||||
/// List of base classes of the newly created type
|
||||
list bases;
|
||||
|
||||
/// Optional docstring
|
||||
const char *doc = nullptr;
|
||||
/// Optional docstring
|
||||
const char *doc = nullptr;
|
||||
|
||||
/// Custom metaclass (optional)
|
||||
handle metaclass;
|
||||
/// Custom metaclass (optional)
|
||||
handle metaclass;
|
||||
|
||||
/// Multiple inheritance marker
|
||||
bool multiple_inheritance : 1;
|
||||
/// Multiple inheritance marker
|
||||
bool multiple_inheritance : 1;
|
||||
|
||||
/// Does the class manage a __dict__?
|
||||
bool dynamic_attr : 1;
|
||||
/// Does the class manage a __dict__?
|
||||
bool dynamic_attr : 1;
|
||||
|
||||
/// Does the class implement the buffer protocol?
|
||||
bool buffer_protocol : 1;
|
||||
/// Does the class implement the buffer protocol?
|
||||
bool buffer_protocol : 1;
|
||||
|
||||
/// Is the default (unique_ptr) holder type used?
|
||||
bool default_holder : 1;
|
||||
/// Is the default (unique_ptr) holder type used?
|
||||
bool default_holder : 1;
|
||||
|
||||
/// Is the class definition local to the module shared object?
|
||||
bool module_local : 1;
|
||||
/// Is the class definition local to the module shared object?
|
||||
bool module_local : 1;
|
||||
|
||||
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
|
||||
auto base_info = detail::get_type_info(base, false);
|
||||
if (!base_info) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) +
|
||||
"\" referenced unknown base type \"" + tname + "\"");
|
||||
}
|
||||
|
||||
if (default_holder != base_info->default_holder) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
|
||||
(default_holder ? "does not have" : "has") +
|
||||
" a non-default holder type while its base \"" + tname + "\" " +
|
||||
(base_info->default_holder ? "does not" : "does"));
|
||||
}
|
||||
|
||||
bases.append((PyObject *) base_info->type);
|
||||
|
||||
if (base_info->type->tp_dictoffset != 0)
|
||||
dynamic_attr = true;
|
||||
|
||||
if (caster)
|
||||
base_info->implicit_casts.emplace_back(type, caster);
|
||||
PYBIND11_NOINLINE void add_base(const std::type_info &base,
|
||||
void *(*caster)(void *)) {
|
||||
auto base_info = detail::get_type_info(base, false);
|
||||
if (!base_info) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) +
|
||||
"\" referenced unknown base type \"" + tname + "\"");
|
||||
}
|
||||
|
||||
if (default_holder != base_info->default_holder) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
|
||||
(default_holder ? "does not have" : "has") +
|
||||
" a non-default holder type while its base \"" + tname +
|
||||
"\" " + (base_info->default_holder ? "does not" : "does"));
|
||||
}
|
||||
|
||||
bases.append((PyObject *)base_info->type);
|
||||
|
||||
if (base_info->type->tp_dictoffset != 0)
|
||||
dynamic_attr = true;
|
||||
|
||||
if (caster)
|
||||
base_info->implicit_casts.emplace_back(type, caster);
|
||||
}
|
||||
};
|
||||
|
||||
inline function_call::function_call(const function_record &f, handle p) :
|
||||
func(f), parent(p) {
|
||||
args.reserve(f.nargs);
|
||||
args_convert.reserve(f.nargs);
|
||||
inline function_call::function_call(const function_record &f, handle p)
|
||||
: func(f), parent(p) {
|
||||
args.reserve(f.nargs);
|
||||
args_convert.reserve(f.nargs);
|
||||
}
|
||||
|
||||
/// Tag for a new-style `__init__` defined in `detail/init.h`
|
||||
struct is_new_style_constructor { };
|
||||
struct is_new_style_constructor {};
|
||||
|
||||
/**
|
||||
* Partial template specializations to process custom attributes provided to
|
||||
@@ -300,135 +331,191 @@ struct is_new_style_constructor { };
|
||||
template <typename T, typename SFINAE = void> struct process_attribute;
|
||||
|
||||
template <typename T> struct process_attribute_default {
|
||||
/// Default implementation: do nothing
|
||||
static void init(const T &, function_record *) { }
|
||||
static void init(const T &, type_record *) { }
|
||||
static void precall(function_call &) { }
|
||||
static void postcall(function_call &, handle) { }
|
||||
/// Default implementation: do nothing
|
||||
static void init(const T &, function_record *) {}
|
||||
static void init(const T &, type_record *) {}
|
||||
static void precall(function_call &) {}
|
||||
static void postcall(function_call &, handle) {}
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's name
|
||||
template <> struct process_attribute<name> : process_attribute_default<name> {
|
||||
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
|
||||
static void init(const name &n, function_record *r) {
|
||||
r->name = const_cast<char *>(n.value);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's docstring
|
||||
template <> struct process_attribute<doc> : process_attribute_default<doc> {
|
||||
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
|
||||
static void init(const doc &n, function_record *r) {
|
||||
r->doc = const_cast<char *>(n.value);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's docstring (provided as a C-style string)
|
||||
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
|
||||
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
|
||||
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
|
||||
/// Process an attribute specifying the function's docstring (provided as a
|
||||
/// C-style string)
|
||||
template <>
|
||||
struct process_attribute<const char *>
|
||||
: process_attribute_default<const char *> {
|
||||
static void init(const char *d, function_record *r) {
|
||||
r->doc = const_cast<char *>(d);
|
||||
}
|
||||
static void init(const char *d, type_record *r) {
|
||||
r->doc = const_cast<char *>(d);
|
||||
}
|
||||
};
|
||||
template <> struct process_attribute<char *> : process_attribute<const char *> { };
|
||||
template <>
|
||||
struct process_attribute<char *> : process_attribute<const char *> {};
|
||||
|
||||
/// Process an attribute indicating the function's return value policy
|
||||
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
|
||||
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
|
||||
template <>
|
||||
struct process_attribute<return_value_policy>
|
||||
: process_attribute_default<return_value_policy> {
|
||||
static void init(const return_value_policy &p, function_record *r) {
|
||||
r->policy = p;
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
|
||||
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
|
||||
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
|
||||
/// Process an attribute which indicates that this is an overloaded function
|
||||
/// associated with a given sibling
|
||||
template <>
|
||||
struct process_attribute<sibling> : process_attribute_default<sibling> {
|
||||
static void init(const sibling &s, function_record *r) {
|
||||
r->sibling = s.value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this function is a method
|
||||
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
|
||||
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
|
||||
template <>
|
||||
struct process_attribute<is_method> : process_attribute_default<is_method> {
|
||||
static void init(const is_method &s, function_record *r) {
|
||||
r->is_method = true;
|
||||
r->scope = s.class_;
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates the parent scope of a method
|
||||
template <> struct process_attribute<scope> : process_attribute_default<scope> {
|
||||
static void init(const scope &s, function_record *r) { r->scope = s.value; }
|
||||
static void init(const scope &s, function_record *r) { r->scope = s.value; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this function is an operator
|
||||
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
|
||||
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
|
||||
template <>
|
||||
struct process_attribute<is_operator> : process_attribute_default<is_operator> {
|
||||
static void init(const is_operator &, function_record *r) {
|
||||
r->is_operator = true;
|
||||
}
|
||||
};
|
||||
|
||||
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
|
||||
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
|
||||
template <>
|
||||
struct process_attribute<is_new_style_constructor>
|
||||
: process_attribute_default<is_new_style_constructor> {
|
||||
static void init(const is_new_style_constructor &, function_record *r) {
|
||||
r->is_new_style_constructor = true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a keyword argument attribute (*without* a default value)
|
||||
template <> struct process_attribute<arg> : process_attribute_default<arg> {
|
||||
static void init(const arg &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
|
||||
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
|
||||
}
|
||||
static void init(const arg &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr, handle(), true /*convert*/,
|
||||
false /*none not allowed*/);
|
||||
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert,
|
||||
a.flag_none);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a keyword argument attribute (*with* a default value)
|
||||
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
|
||||
static void init(const arg_v &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
|
||||
static void init(const arg_v &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/,
|
||||
true /*convert*/, false /*none not allowed*/);
|
||||
|
||||
if (!a.value) {
|
||||
if (!a.value) {
|
||||
#if !defined(NDEBUG)
|
||||
std::string descr("'");
|
||||
if (a.name) descr += std::string(a.name) + ": ";
|
||||
descr += a.type + "'";
|
||||
if (r->is_method) {
|
||||
if (r->name)
|
||||
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
|
||||
else
|
||||
descr += " in method of '" + (std::string) str(r->scope) + "'";
|
||||
} else if (r->name) {
|
||||
descr += " in function '" + (std::string) r->name + "'";
|
||||
}
|
||||
pybind11_fail("arg(): could not convert default argument "
|
||||
+ descr + " into a Python object (type not registered yet?)");
|
||||
std::string descr("'");
|
||||
if (a.name)
|
||||
descr += std::string(a.name) + ": ";
|
||||
descr += a.type + "'";
|
||||
if (r->is_method) {
|
||||
if (r->name)
|
||||
descr += " in method '" + (std::string)str(r->scope) + "." +
|
||||
(std::string)r->name + "'";
|
||||
else
|
||||
descr += " in method of '" + (std::string)str(r->scope) + "'";
|
||||
} else if (r->name) {
|
||||
descr += " in function '" + (std::string)r->name + "'";
|
||||
}
|
||||
pybind11_fail("arg(): could not convert default argument " + descr +
|
||||
" into a Python object (type not registered yet?)");
|
||||
#else
|
||||
pybind11_fail("arg(): could not convert default argument "
|
||||
"into a Python object (type not registered yet?). "
|
||||
"Compile in debug mode for more information.");
|
||||
pybind11_fail("arg(): could not convert default argument "
|
||||
"into a Python object (type not registered yet?). "
|
||||
"Compile in debug mode for more information.");
|
||||
#endif
|
||||
}
|
||||
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
|
||||
}
|
||||
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert,
|
||||
a.flag_none);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
|
||||
/// Process a parent class attribute. Single inheritance only (class_ itself
|
||||
/// already guarantees that)
|
||||
template <typename T>
|
||||
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
|
||||
static void init(const handle &h, type_record *r) { r->bases.append(h); }
|
||||
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>>
|
||||
: process_attribute_default<handle> {
|
||||
static void init(const handle &h, type_record *r) { r->bases.append(h); }
|
||||
};
|
||||
|
||||
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
|
||||
/// Process a parent class attribute (deprecated, does not support multiple
|
||||
/// inheritance)
|
||||
template <typename T>
|
||||
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
|
||||
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
|
||||
static void init(const base<T> &, type_record *r) {
|
||||
r->add_base(typeid(T), nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a multiple inheritance attribute
|
||||
template <>
|
||||
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
|
||||
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
|
||||
struct process_attribute<multiple_inheritance>
|
||||
: process_attribute_default<multiple_inheritance> {
|
||||
static void init(const multiple_inheritance &, type_record *r) {
|
||||
r->multiple_inheritance = true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
|
||||
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
|
||||
struct process_attribute<dynamic_attr>
|
||||
: process_attribute_default<dynamic_attr> {
|
||||
static void init(const dynamic_attr &, type_record *r) {
|
||||
r->dynamic_attr = true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
|
||||
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
|
||||
struct process_attribute<buffer_protocol>
|
||||
: process_attribute_default<buffer_protocol> {
|
||||
static void init(const buffer_protocol &, type_record *r) {
|
||||
r->buffer_protocol = true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
|
||||
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
|
||||
static void init(const metaclass &m, type_record *r) {
|
||||
r->metaclass = m.value;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<module_local> : process_attribute_default<module_local> {
|
||||
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
|
||||
struct process_attribute<module_local>
|
||||
: process_attribute_default<module_local> {
|
||||
static void init(const module_local &l, type_record *r) {
|
||||
r->module_local = l.value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Process an 'arithmetic' attribute for enums (does nothing here)
|
||||
@@ -436,57 +523,78 @@ template <>
|
||||
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
|
||||
|
||||
template <typename... Ts>
|
||||
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
|
||||
struct process_attribute<call_guard<Ts...>>
|
||||
: process_attribute_default<call_guard<Ts...>> {};
|
||||
|
||||
/**
|
||||
* Process a keep_alive call policy -- invokes keep_alive_impl during the
|
||||
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
|
||||
* otherwise
|
||||
*/
|
||||
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void postcall(function_call &, handle) { }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void precall(function_call &) { }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
|
||||
template <size_t Nurse, size_t Patient>
|
||||
struct process_attribute<keep_alive<Nurse, Patient>>
|
||||
: public process_attribute_default<keep_alive<Nurse, Patient>> {
|
||||
template <size_t N = Nurse, size_t P = Patient,
|
||||
enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void precall(function_call &call) {
|
||||
keep_alive_impl(Nurse, Patient, call, handle());
|
||||
}
|
||||
template <size_t N = Nurse, size_t P = Patient,
|
||||
enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void postcall(function_call &, handle) {}
|
||||
template <size_t N = Nurse, size_t P = Patient,
|
||||
enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void precall(function_call &) {}
|
||||
template <size_t N = Nurse, size_t P = Patient,
|
||||
enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void postcall(function_call &call, handle ret) {
|
||||
keep_alive_impl(Nurse, Patient, call, ret);
|
||||
}
|
||||
};
|
||||
|
||||
/// Recursively iterate over variadic template arguments
|
||||
template <typename... Args> struct process_attributes {
|
||||
static void init(const Args&... args, function_record *r) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void init(const Args&... args, type_record *r) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void precall(function_call &call) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void postcall(function_call &call, handle fn_ret) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void init(const Args &... args, function_record *r) {
|
||||
int unused[] = {
|
||||
0, (process_attribute<typename std::decay<Args>::type>::init(args, r),
|
||||
0)...};
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void init(const Args &... args, type_record *r) {
|
||||
int unused[] = {
|
||||
0, (process_attribute<typename std::decay<Args>::type>::init(args, r),
|
||||
0)...};
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void precall(function_call &call) {
|
||||
int unused[] = {
|
||||
0, (process_attribute<typename std::decay<Args>::type>::precall(call),
|
||||
0)...};
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void postcall(function_call &call, handle fn_ret) {
|
||||
int unused[] = {
|
||||
0, (process_attribute<typename std::decay<Args>::type>::postcall(
|
||||
call, fn_ret),
|
||||
0)...};
|
||||
ignore_unused(unused);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using is_call_guard = is_instantiation<call_guard, T>;
|
||||
template <typename T> using is_call_guard = is_instantiation<call_guard, T>;
|
||||
|
||||
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
|
||||
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or
|
||||
/// `void_type` if none found)
|
||||
template <typename... Extra>
|
||||
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
|
||||
using extract_guard_t =
|
||||
typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
|
||||
|
||||
/// Check the number of named arguments at compile time
|
||||
template <typename... Extra,
|
||||
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
|
||||
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
|
||||
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
|
||||
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
|
||||
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
|
||||
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -15,93 +15,112 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
/// Information record describing a Python buffer object
|
||||
struct buffer_info {
|
||||
void *ptr = nullptr; // Pointer to the underlying storage
|
||||
ssize_t itemsize = 0; // Size of individual items in bytes
|
||||
ssize_t size = 0; // Total number of entries
|
||||
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
|
||||
ssize_t ndim = 0; // Number of dimensions
|
||||
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
|
||||
std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
|
||||
void *ptr = nullptr; // Pointer to the underlying storage
|
||||
ssize_t itemsize = 0; // Size of individual items in bytes
|
||||
ssize_t size = 0; // Total number of entries
|
||||
std::string format; // For homogeneous buffers, this should be set to
|
||||
// format_descriptor<T>::format()
|
||||
ssize_t ndim = 0; // Number of dimensions
|
||||
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
|
||||
std::vector<ssize_t> strides; // Number of entries between adjacent entries
|
||||
// (for each per dimension)
|
||||
|
||||
buffer_info() { }
|
||||
buffer_info() {}
|
||||
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
|
||||
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
|
||||
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
|
||||
shape(std::move(shape_in)), strides(std::move(strides_in)) {
|
||||
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
|
||||
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
|
||||
for (size_t i = 0; i < (size_t) ndim; ++i)
|
||||
size *= shape[i];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
|
||||
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
|
||||
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
|
||||
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, ssize_t size)
|
||||
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
|
||||
|
||||
explicit buffer_info(Py_buffer *view, bool ownview = true)
|
||||
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
|
||||
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
|
||||
this->view = view;
|
||||
this->ownview = ownview;
|
||||
}
|
||||
|
||||
buffer_info(const buffer_info &) = delete;
|
||||
buffer_info& operator=(const buffer_info &) = delete;
|
||||
|
||||
buffer_info(buffer_info &&other) {
|
||||
(*this) = std::move(other);
|
||||
}
|
||||
|
||||
buffer_info& operator=(buffer_info &&rhs) {
|
||||
ptr = rhs.ptr;
|
||||
itemsize = rhs.itemsize;
|
||||
size = rhs.size;
|
||||
format = std::move(rhs.format);
|
||||
ndim = rhs.ndim;
|
||||
shape = std::move(rhs.shape);
|
||||
strides = std::move(rhs.strides);
|
||||
std::swap(view, rhs.view);
|
||||
std::swap(ownview, rhs.ownview);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~buffer_info() {
|
||||
if (view && ownview) { PyBuffer_Release(view); delete view; }
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format,
|
||||
ssize_t ndim, detail::any_container<ssize_t> shape_in,
|
||||
detail::any_container<ssize_t> strides_in)
|
||||
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
|
||||
shape(std::move(shape_in)), strides(std::move(strides_in)) {
|
||||
if (ndim != (ssize_t)shape.size() || ndim != (ssize_t)strides.size())
|
||||
pybind11_fail(
|
||||
"buffer_info: ndim doesn't match shape and/or strides length");
|
||||
for (size_t i = 0; i < (size_t)ndim; ++i)
|
||||
size *= shape[i];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in,
|
||||
detail::any_container<ssize_t> strides_in)
|
||||
: buffer_info(private_ctr_tag(), ptr, sizeof(T),
|
||||
format_descriptor<T>::format(),
|
||||
static_cast<ssize_t>(shape_in->size()), std::move(shape_in),
|
||||
std::move(strides_in)) {}
|
||||
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format,
|
||||
ssize_t size)
|
||||
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) {}
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, ssize_t size)
|
||||
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) {}
|
||||
|
||||
explicit buffer_info(Py_buffer *view, bool ownview = true)
|
||||
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
|
||||
{view->shape, view->shape + view->ndim},
|
||||
{view->strides, view->strides + view->ndim}) {
|
||||
this->view = view;
|
||||
this->ownview = ownview;
|
||||
}
|
||||
|
||||
buffer_info(const buffer_info &) = delete;
|
||||
buffer_info &operator=(const buffer_info &) = delete;
|
||||
|
||||
buffer_info(buffer_info &&other) { (*this) = std::move(other); }
|
||||
|
||||
buffer_info &operator=(buffer_info &&rhs) {
|
||||
ptr = rhs.ptr;
|
||||
itemsize = rhs.itemsize;
|
||||
size = rhs.size;
|
||||
format = std::move(rhs.format);
|
||||
ndim = rhs.ndim;
|
||||
shape = std::move(rhs.shape);
|
||||
strides = std::move(rhs.strides);
|
||||
std::swap(view, rhs.view);
|
||||
std::swap(ownview, rhs.ownview);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~buffer_info() {
|
||||
if (view && ownview) {
|
||||
PyBuffer_Release(view);
|
||||
delete view;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_ctr_tag { };
|
||||
struct private_ctr_tag {};
|
||||
|
||||
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
|
||||
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
|
||||
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
|
||||
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize,
|
||||
const std::string &format, ssize_t ndim,
|
||||
detail::any_container<ssize_t> &&shape_in,
|
||||
detail::any_container<ssize_t> &&strides_in)
|
||||
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in),
|
||||
std::move(strides_in)) {}
|
||||
|
||||
Py_buffer *view = nullptr;
|
||||
bool ownview = false;
|
||||
Py_buffer *view = nullptr;
|
||||
bool ownview = false;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T, typename SFINAE = void> struct compare_buffer_info {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
|
||||
}
|
||||
static bool compare(const buffer_info &b) {
|
||||
return b.format == format_descriptor<T>::format() &&
|
||||
b.itemsize == (ssize_t)sizeof(T);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
|
||||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
|
||||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
|
||||
}
|
||||
template <typename T>
|
||||
struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
||||
static bool compare(const buffer_info &b) {
|
||||
return (size_t)b.itemsize == sizeof(T) &&
|
||||
(b.format == format_descriptor<T>::value ||
|
||||
((sizeof(T) == sizeof(long)) &&
|
||||
b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
|
||||
((sizeof(T) == sizeof(size_t)) &&
|
||||
b.format == (std::is_unsigned<T>::value ? "N" : "n")));
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
/*
|
||||
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
|
||||
pybind11/chrono.h: Transparent conversion between std::chrono and python's
|
||||
datetime
|
||||
|
||||
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
|
||||
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
@@ -11,20 +12,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <chrono>
|
||||
#include <datetime.h>
|
||||
|
||||
// Backport the PyDateTime_DELTA functions from Python3.3 if required
|
||||
#ifndef PyDateTime_DELTA_GET_DAYS
|
||||
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
|
||||
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta *)o)->days)
|
||||
#endif
|
||||
#ifndef PyDateTime_DELTA_GET_SECONDS
|
||||
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
|
||||
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta *)o)->seconds)
|
||||
#endif
|
||||
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
|
||||
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
|
||||
#define PyDateTime_DELTA_GET_MICROSECONDS(o) \
|
||||
(((PyDateTime_Delta *)o)->microseconds)
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
@@ -32,131 +34,154 @@ NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename type> class duration_caster {
|
||||
public:
|
||||
typedef typename type::rep rep;
|
||||
typedef typename type::period period;
|
||||
typedef typename type::rep rep;
|
||||
typedef typename type::period period;
|
||||
|
||||
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
|
||||
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
|
||||
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
if (!src) return false;
|
||||
// If invoked with datetime.delta object
|
||||
if (PyDelta_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(
|
||||
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
|
||||
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
|
||||
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
|
||||
return true;
|
||||
}
|
||||
// If invoked with a float we assume it is seconds and convert
|
||||
else if (PyFloat_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
|
||||
return true;
|
||||
}
|
||||
else return false;
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) {
|
||||
PyDateTime_IMPORT;
|
||||
}
|
||||
|
||||
// If this is a duration just return it back
|
||||
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
|
||||
return src;
|
||||
if (!src)
|
||||
return false;
|
||||
// If invoked with datetime.delta object
|
||||
if (PyDelta_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(
|
||||
days(PyDateTime_DELTA_GET_DAYS(src.ptr())) +
|
||||
seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) +
|
||||
microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
|
||||
return true;
|
||||
}
|
||||
// If invoked with a float we assume it is seconds and convert
|
||||
else if (PyFloat_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(
|
||||
duration<double>(PyFloat_AsDouble(src.ptr()))));
|
||||
return true;
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
|
||||
// If this is a duration just return it back
|
||||
static const std::chrono::duration<rep, period> &
|
||||
get_duration(const std::chrono::duration<rep, period> &src) {
|
||||
return src;
|
||||
}
|
||||
|
||||
// If this is a time_point get the time_since_epoch
|
||||
template <typename Clock>
|
||||
static std::chrono::duration<rep, period> get_duration(
|
||||
const std::chrono::time_point<Clock, std::chrono::duration<rep, period>>
|
||||
&src) {
|
||||
return src.time_since_epoch();
|
||||
}
|
||||
|
||||
static handle cast(const type &src, return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Use overloaded function to get our duration from our source
|
||||
// Works out if it is a duration or time_point and get the duration
|
||||
auto d = get_duration(src);
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) {
|
||||
PyDateTime_IMPORT;
|
||||
}
|
||||
|
||||
// If this is a time_point get the time_since_epoch
|
||||
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
|
||||
return src.time_since_epoch();
|
||||
}
|
||||
// Declare these special duration types so the conversions happen with the
|
||||
// correct primitive types (int)
|
||||
using dd_t = duration<int, std::ratio<86400>>;
|
||||
using ss_t = duration<int, std::ratio<1>>;
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
auto dd = duration_cast<dd_t>(d);
|
||||
auto subd = d - dd;
|
||||
auto ss = duration_cast<ss_t>(subd);
|
||||
auto us = duration_cast<us_t>(subd - ss);
|
||||
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
|
||||
}
|
||||
|
||||
// Use overloaded function to get our duration from our source
|
||||
// Works out if it is a duration or time_point and get the duration
|
||||
auto d = get_duration(src);
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
// Declare these special duration types so the conversions happen with the correct primitive types (int)
|
||||
using dd_t = duration<int, std::ratio<86400>>;
|
||||
using ss_t = duration<int, std::ratio<1>>;
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
auto dd = duration_cast<dd_t>(d);
|
||||
auto subd = d - dd;
|
||||
auto ss = duration_cast<ss_t>(subd);
|
||||
auto us = duration_cast<us_t>(subd - ss);
|
||||
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
|
||||
};
|
||||
|
||||
// This is for casting times on the system clock into datetime.datetime instances
|
||||
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
|
||||
// This is for casting times on the system clock into datetime.datetime
|
||||
// instances
|
||||
template <typename Duration>
|
||||
class type_caster<
|
||||
std::chrono::time_point<std::chrono::system_clock, Duration>> {
|
||||
public:
|
||||
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
if (!src) return false;
|
||||
if (PyDateTime_Check(src.ptr())) {
|
||||
std::tm cal;
|
||||
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
|
||||
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
|
||||
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
|
||||
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
|
||||
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
|
||||
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
|
||||
cal.tm_isdst = -1;
|
||||
|
||||
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
|
||||
return true;
|
||||
}
|
||||
else return false;
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) {
|
||||
PyDateTime_IMPORT;
|
||||
}
|
||||
|
||||
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
if (!src)
|
||||
return false;
|
||||
if (PyDateTime_Check(src.ptr())) {
|
||||
std::tm cal;
|
||||
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
|
||||
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
|
||||
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
|
||||
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
|
||||
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
|
||||
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
|
||||
cal.tm_isdst = -1;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
value = system_clock::from_time_t(std::mktime(&cal)) +
|
||||
microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
|
||||
return true;
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
|
||||
std::time_t tt = system_clock::to_time_t(src);
|
||||
// this function uses static memory so it's best to copy it out asap just in case
|
||||
// otherwise other code that is using localtime may break this (not just python code)
|
||||
std::tm localtime = *std::localtime(&tt);
|
||||
static handle
|
||||
cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src,
|
||||
return_value_policy /* policy */, handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Declare these special duration types so the conversions happen with the correct primitive types (int)
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
|
||||
localtime.tm_mon + 1,
|
||||
localtime.tm_mday,
|
||||
localtime.tm_hour,
|
||||
localtime.tm_min,
|
||||
localtime.tm_sec,
|
||||
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) {
|
||||
PyDateTime_IMPORT;
|
||||
}
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
|
||||
|
||||
std::time_t tt = system_clock::to_time_t(src);
|
||||
// this function uses static memory so it's best to copy it out asap just in
|
||||
// case otherwise other code that is using localtime may break this (not
|
||||
// just python code)
|
||||
std::tm localtime = *std::localtime(&tt);
|
||||
|
||||
// Declare these special duration types so the conversions happen with the
|
||||
// correct primitive types (int)
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
return PyDateTime_FromDateAndTime(
|
||||
localtime.tm_year + 1900, localtime.tm_mon + 1, localtime.tm_mday,
|
||||
localtime.tm_hour, localtime.tm_min, localtime.tm_sec,
|
||||
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
|
||||
}
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
|
||||
};
|
||||
|
||||
// Other clocks that are not the system clock are not measured as datetime.datetime objects
|
||||
// since they are not measured on calendar time. So instead we just make them timedeltas
|
||||
// Or if they have passed us a time as a float we convert that
|
||||
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
|
||||
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
|
||||
};
|
||||
// Other clocks that are not the system clock are not measured as
|
||||
// datetime.datetime objects since they are not measured on calendar time. So
|
||||
// instead we just make them timedeltas Or if they have passed us a time as a
|
||||
// float we convert that
|
||||
template <typename Clock, typename Duration>
|
||||
class type_caster<std::chrono::time_point<Clock, Duration>>
|
||||
: public duration_caster<std::chrono::time_point<Clock, Duration>> {};
|
||||
|
||||
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
|
||||
: public duration_caster<std::chrono::duration<Rep, Period>> {
|
||||
};
|
||||
template <typename Rep, typename Period>
|
||||
class type_caster<std::chrono::duration<Rep, Period>>
|
||||
: public duration_caster<std::chrono::duration<Rep, Period>> {};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -1,2 +1,3 @@
|
||||
#include "detail/common.h"
|
||||
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
|
||||
#warning \
|
||||
"Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
|
||||
|
@@ -14,52 +14,59 @@
|
||||
|
||||
/// glibc defines I as a macro which breaks things, e.g., boost template names
|
||||
#ifdef I
|
||||
# undef I
|
||||
#undef I
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr const char c = format_descriptor<T>::c;
|
||||
static constexpr const char value[3] = { 'Z', c, '\0' };
|
||||
static std::string format() { return std::string(value); }
|
||||
template <typename T>
|
||||
struct format_descriptor<
|
||||
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr const char c = format_descriptor<T>::c;
|
||||
static constexpr const char value[3] = {'Z', c, '\0'};
|
||||
static std::string format() { return std::string(value); }
|
||||
};
|
||||
|
||||
#ifndef PYBIND11_CPP17
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
|
||||
template <typename T>
|
||||
constexpr const char format_descriptor<
|
||||
std::complex<T>,
|
||||
detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
|
||||
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr bool value = true;
|
||||
static constexpr int index = is_fmt_numeric<T>::index + 3;
|
||||
template <typename T>
|
||||
struct is_fmt_numeric<std::complex<T>,
|
||||
detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr bool value = true;
|
||||
static constexpr int index = is_fmt_numeric<T>::index + 3;
|
||||
};
|
||||
|
||||
template <typename T> class type_caster<std::complex<T>> {
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src)
|
||||
return false;
|
||||
if (!convert && !PyComplex_Check(src.ptr()))
|
||||
return false;
|
||||
Py_complex result = PyComplex_AsCComplex(src.ptr());
|
||||
if (result.real == -1.0 && PyErr_Occurred()) {
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
value = std::complex<T>((T) result.real, (T) result.imag);
|
||||
return true;
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src)
|
||||
return false;
|
||||
if (!convert && !PyComplex_Check(src.ptr()))
|
||||
return false;
|
||||
Py_complex result = PyComplex_AsCComplex(src.ptr());
|
||||
if (result.real == -1.0 && PyErr_Occurred()) {
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
value = std::complex<T>((T)result.real, (T)result.imag);
|
||||
return true;
|
||||
}
|
||||
|
||||
static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
|
||||
}
|
||||
static handle cast(const std::complex<T> &src,
|
||||
return_value_policy /* policy */, handle /* parent */) {
|
||||
return PyComplex_FromDoubles((double)src.real(), (double)src.imag());
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
|
||||
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
|
||||
};
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
/*
|
||||
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
|
||||
pybind11/detail/descr.h: Helper type for concatenating type signatures at
|
||||
compile time
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
@@ -15,67 +16,82 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
# define PYBIND11_DESCR_CONSTEXPR static constexpr
|
||||
#define PYBIND11_DESCR_CONSTEXPR static constexpr
|
||||
#else
|
||||
# define PYBIND11_DESCR_CONSTEXPR const
|
||||
#define PYBIND11_DESCR_CONSTEXPR const
|
||||
#endif
|
||||
|
||||
/* Concatenate type signatures at compile time */
|
||||
template <size_t N, typename... Ts>
|
||||
struct descr {
|
||||
char text[N + 1];
|
||||
template <size_t N, typename... Ts> struct descr {
|
||||
char text[N + 1];
|
||||
|
||||
constexpr descr() : text{'\0'} { }
|
||||
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
|
||||
constexpr descr() : text{'\0'} {}
|
||||
constexpr descr(char const (&s)[N + 1])
|
||||
: descr(s, make_index_sequence<N>()) {}
|
||||
|
||||
template <size_t... Is>
|
||||
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
|
||||
template <size_t... Is>
|
||||
constexpr descr(char const (&s)[N + 1], index_sequence<Is...>)
|
||||
: text{s[Is]..., '\0'} {}
|
||||
|
||||
template <typename... Chars>
|
||||
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
|
||||
template <typename... Chars>
|
||||
constexpr descr(char c, Chars... cs)
|
||||
: text{c, static_cast<char>(cs)..., '\0'} {}
|
||||
|
||||
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
|
||||
return {{&typeid(Ts)..., nullptr}};
|
||||
}
|
||||
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1>
|
||||
types() {
|
||||
return {{&typeid(Ts)..., nullptr}};
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
|
||||
index_sequence<Is1...>, index_sequence<Is2...>) {
|
||||
return {a.text[Is1]..., b.text[Is2]...};
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1,
|
||||
size_t... Is2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...>
|
||||
plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
|
||||
index_sequence<Is1...>, index_sequence<Is2...>) {
|
||||
return {a.text[Is1]..., b.text[Is2]...};
|
||||
}
|
||||
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
|
||||
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a,
|
||||
const descr<N2, Ts2...> &b) {
|
||||
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
|
||||
constexpr descr<0> _(char const(&)[1]) { return {}; }
|
||||
template <size_t N> constexpr descr<N - 1> _(char const (&text)[N]) {
|
||||
return descr<N - 1>(text);
|
||||
}
|
||||
constexpr descr<0> _(char const (&)[1]) { return {}; }
|
||||
|
||||
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
|
||||
template <size_t...Digits> struct int_to_str<0, Digits...> {
|
||||
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
|
||||
template <size_t Rem, size_t... Digits>
|
||||
struct int_to_str : int_to_str<Rem / 10, Rem % 10, Digits...> {};
|
||||
template <size_t... Digits> struct int_to_str<0, Digits...> {
|
||||
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
|
||||
};
|
||||
|
||||
// Ternary description (like std::conditional)
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
|
||||
return _(text1);
|
||||
constexpr enable_if_t<B, descr<N1 - 1>> _(char const (&text1)[N1],
|
||||
char const (&)[N2]) {
|
||||
return _(text1);
|
||||
}
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
|
||||
return _(text2);
|
||||
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const (&)[N1],
|
||||
char const (&text2)[N2]) {
|
||||
return _(text2);
|
||||
}
|
||||
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
|
||||
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) {
|
||||
return d;
|
||||
}
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
|
||||
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) {
|
||||
return d;
|
||||
}
|
||||
|
||||
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
|
||||
return int_to_str<Size / 10, Size % 10>::digits;
|
||||
template <size_t Size>
|
||||
auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
|
||||
return int_to_str<Size / 10, Size % 10>::digits;
|
||||
}
|
||||
|
||||
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
|
||||
@@ -83,17 +99,19 @@ template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
|
||||
constexpr descr<0> concat() { return {}; }
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
|
||||
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) {
|
||||
return descr;
|
||||
}
|
||||
|
||||
template <size_t N, typename... Ts, typename... Args>
|
||||
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
|
||||
constexpr auto concat(const descr<N, Ts...> &d, const Args &... args)
|
||||
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
|
||||
return d + _(", ") + concat(args...);
|
||||
return d + _(", ") + concat(args...);
|
||||
}
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
|
||||
return _("{") + descr + _("}");
|
||||
return _("{") + descr + _("}");
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
/*
|
||||
pybind11/detail/init.h: init factory function implementation and support code.
|
||||
pybind11/detail/init.h: init factory function implementation and support
|
||||
code.
|
||||
|
||||
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
|
||||
|
||||
@@ -14,26 +15,26 @@
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <>
|
||||
class type_caster<value_and_holder> {
|
||||
template <> class type_caster<value_and_holder> {
|
||||
public:
|
||||
bool load(handle h, bool) {
|
||||
value = reinterpret_cast<value_and_holder *>(h.ptr());
|
||||
return true;
|
||||
}
|
||||
bool load(handle h, bool) {
|
||||
value = reinterpret_cast<value_and_holder *>(h.ptr());
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename> using cast_op_type = value_and_holder &;
|
||||
operator value_and_holder &() { return *value; }
|
||||
static constexpr auto name = _<value_and_holder>();
|
||||
template <typename> using cast_op_type = value_and_holder &;
|
||||
operator value_and_holder &() { return *value; }
|
||||
static constexpr auto name = _<value_and_holder>();
|
||||
|
||||
private:
|
||||
value_and_holder *value = nullptr;
|
||||
value_and_holder *value = nullptr;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(initimpl)
|
||||
|
||||
inline void no_nullptr(void *ptr) {
|
||||
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
|
||||
if (!ptr)
|
||||
throw type_error("pybind11::init(): factory function returned nullptr");
|
||||
}
|
||||
|
||||
// Implementing functions for all forms of py::init<...> and py::init(...)
|
||||
@@ -41,293 +42,372 @@ template <typename Class> using Cpp = typename Class::type;
|
||||
template <typename Class> using Alias = typename Class::type_alias;
|
||||
template <typename Class> using Holder = typename Class::holder_type;
|
||||
|
||||
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
|
||||
template <typename Class>
|
||||
using is_alias_constructible =
|
||||
std::is_constructible<Alias<Class>, Cpp<Class> &&>;
|
||||
|
||||
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
|
||||
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias
|
||||
// instance.
|
||||
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
||||
bool is_alias(Cpp<Class> *ptr) {
|
||||
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
|
||||
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
|
||||
}
|
||||
// Failing fallback version of the above for a no-alias class (always returns false)
|
||||
template <typename /*Class*/>
|
||||
constexpr bool is_alias(void *) { return false; }
|
||||
// Failing fallback version of the above for a no-alias class (always returns
|
||||
// false)
|
||||
template <typename /*Class*/> constexpr bool is_alias(void *) { return false; }
|
||||
|
||||
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
|
||||
// back to brace aggregate initiailization so that for aggregate initialization can be used with
|
||||
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
|
||||
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
|
||||
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
|
||||
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
|
||||
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
|
||||
// Constructs and returns a new object; if the given arguments don't map to a
|
||||
// constructor, we fall back to brace aggregate initiailization so that for
|
||||
// aggregate initialization can be used with py::init, e.g. `py::init<int,
|
||||
// int>` to initialize a `struct T { int a; int b; }`. For non-aggregate types,
|
||||
// we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
|
||||
// works, but will not do the expected thing when `T` has an
|
||||
// `initializer_list<T>` constructor).
|
||||
template <
|
||||
typename Class, typename... Args,
|
||||
detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&... args) {
|
||||
return new Class(std::forward<Args>(args)...);
|
||||
}
|
||||
template <
|
||||
typename Class, typename... Args,
|
||||
detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&... args) {
|
||||
return new Class{std::forward<Args>(args)...};
|
||||
}
|
||||
|
||||
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
|
||||
// an alias to provide only a single Cpp factory function as long as the Alias can be
|
||||
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
|
||||
// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
|
||||
// inherit all the base class constructors.
|
||||
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This
|
||||
// allows types with an alias to provide only a single Cpp factory function as
|
||||
// long as the Alias can be constructed from an rvalue reference of the base Cpp
|
||||
// type. This means that Alias classes can, when appropriate, simply define a
|
||||
// `Alias(Cpp &&)` constructor rather than needing to inherit all the base class
|
||||
// constructors.
|
||||
template <typename Class>
|
||||
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
|
||||
value_and_holder &v_h, Cpp<Class> &&base) {
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(base));
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(base));
|
||||
}
|
||||
template <typename Class>
|
||||
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
|
||||
value_and_holder &, Cpp<Class> &&) {
|
||||
throw type_error("pybind11::init(): unable to convert returned instance to required "
|
||||
"alias class: no `Alias<Class>(Class &&)` constructor available");
|
||||
[[noreturn]] void
|
||||
construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
|
||||
value_and_holder &, Cpp<Class> &&) {
|
||||
throw type_error(
|
||||
"pybind11::init(): unable to convert returned instance to required "
|
||||
"alias class: no `Alias<Class>(Class &&)` constructor available");
|
||||
}
|
||||
|
||||
// Error-generating fallback for factories that don't match one of the below construction
|
||||
// mechanisms.
|
||||
template <typename Class>
|
||||
void construct(...) {
|
||||
static_assert(!std::is_same<Class, Class>::value /* always false */,
|
||||
"pybind11::init(): init function must return a compatible pointer, "
|
||||
"holder, or value");
|
||||
// Error-generating fallback for factories that don't match one of the below
|
||||
// construction mechanisms.
|
||||
template <typename Class> void construct(...) {
|
||||
static_assert(
|
||||
!std::is_same<Class, Class>::value /* always false */,
|
||||
"pybind11::init(): init function must return a compatible pointer, "
|
||||
"holder, or value");
|
||||
}
|
||||
|
||||
// Pointer return v1: the factory function returns a class pointer for a registered class.
|
||||
// If we don't need an alias (because this class doesn't have one, or because the final type is
|
||||
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
|
||||
// construct an Alias from the returned base instance.
|
||||
// Pointer return v1: the factory function returns a class pointer for a
|
||||
// registered class. If we don't need an alias (because this class doesn't have
|
||||
// one, or because the final type is inherited on the Python side) we can simply
|
||||
// take over ownership. Otherwise we need to try to construct an Alias from the
|
||||
// returned base instance.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
|
||||
no_nullptr(ptr);
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
|
||||
// We're going to try to construct an alias by moving the cpp type. Whether or not
|
||||
// that succeeds, we still need to destroy the original cpp pointer (either the
|
||||
// moved away leftover, if the alias construction works, or the value itself if we
|
||||
// throw an error), but we can't just call `delete ptr`: it might have a special
|
||||
// deleter, or might be shared_from_this. So we construct a holder around it as if
|
||||
// it was a normal instance, then steal the holder away into a local variable; thus
|
||||
// the holder and destruction happens when we leave the C++ scope, and the holder
|
||||
// class gets to handle the destruction however it likes.
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.set_instance_registered(true); // To prevent init_instance from registering it
|
||||
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
|
||||
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
|
||||
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
|
||||
v_h.set_instance_registered(false);
|
||||
no_nullptr(ptr);
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
|
||||
// We're going to try to construct an alias by moving the cpp type. Whether
|
||||
// or not that succeeds, we still need to destroy the original cpp pointer
|
||||
// (either the moved away leftover, if the alias construction works, or the
|
||||
// value itself if we throw an error), but we can't just call `delete ptr`:
|
||||
// it might have a special deleter, or might be shared_from_this. So we
|
||||
// construct a holder around it as if it was a normal instance, then steal
|
||||
// the holder away into a local variable; thus the holder and destruction
|
||||
// happens when we leave the C++ scope, and the holder class gets to handle
|
||||
// the destruction however it likes.
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.set_instance_registered(
|
||||
true); // To prevent init_instance from registering it
|
||||
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
|
||||
Holder<Class> temp_holder(
|
||||
std::move(v_h.holder<Holder<Class>>())); // Steal the holder
|
||||
v_h.type->dealloc(
|
||||
v_h); // Destroys the moved-out holder remains, resets value ptr to null
|
||||
v_h.set_instance_registered(false);
|
||||
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
|
||||
} else {
|
||||
// Otherwise the type isn't inherited, so we don't need an Alias
|
||||
v_h.value_ptr() = ptr;
|
||||
}
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h,
|
||||
std::move(*ptr));
|
||||
} else {
|
||||
// Otherwise the type isn't inherited, so we don't need an Alias
|
||||
v_h.value_ptr() = ptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
|
||||
// ownership of the pointer.
|
||||
// Pointer return v2: a factory that always returns an alias instance ptr. We
|
||||
// simply take over ownership of the pointer.
|
||||
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
||||
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
|
||||
no_nullptr(alias_ptr);
|
||||
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
|
||||
no_nullptr(alias_ptr);
|
||||
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
|
||||
}
|
||||
|
||||
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
|
||||
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
|
||||
// derived type (through those holder's implicit conversion from derived class holder constructors).
|
||||
// Holder return: copy its pointer, and move or copy the returned holder into
|
||||
// the new instance's holder. This also handles types like std::shared_ptr<T>
|
||||
// and std::unique_ptr<T> where T is a derived type (through those holder's
|
||||
// implicit conversion from derived class holder constructors).
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
|
||||
auto *ptr = holder_helper<Holder<Class>>::get(holder);
|
||||
// If we need an alias, check that the held pointer is actually an alias instance
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
|
||||
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
|
||||
"is not an alias instance");
|
||||
auto *ptr = holder_helper<Holder<Class>>::get(holder);
|
||||
// If we need an alias, check that the held pointer is actually an alias
|
||||
// instance
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
|
||||
throw type_error("pybind11::init(): construction failed: returned "
|
||||
"holder-wrapped instance "
|
||||
"is not an alias instance");
|
||||
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.type->init_instance(v_h.inst, &holder);
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.type->init_instance(v_h.inst, &holder);
|
||||
}
|
||||
|
||||
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
|
||||
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
|
||||
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
|
||||
// need it, we simply move-construct the cpp value into a new instance.
|
||||
// return-by-value version 1: returning a cpp class by value. If the class has
|
||||
// an alias and an alias is required the alias must have an `Alias(Cpp &&)`
|
||||
// constructor so that we can construct the alias from the base when needed
|
||||
// (i.e. because of Python-side inheritance). When we don't need it, we simply
|
||||
// move-construct the cpp value into a new instance.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
|
||||
static_assert(std::is_move_constructible<Cpp<Class>>::value,
|
||||
"pybind11::init() return-by-value factory function requires a movable class");
|
||||
if (Class::has_alias && need_alias)
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
|
||||
else
|
||||
v_h.value_ptr() = new Cpp<Class>(std::move(result));
|
||||
static_assert(std::is_move_constructible<Cpp<Class>>::value,
|
||||
"pybind11::init() return-by-value factory function requires a "
|
||||
"movable class");
|
||||
if (Class::has_alias && need_alias)
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h,
|
||||
std::move(result));
|
||||
else
|
||||
v_h.value_ptr() = new Cpp<Class>(std::move(result));
|
||||
}
|
||||
|
||||
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
|
||||
// Alias instance (even if no the python-side inheritance is involved). The is intended for
|
||||
// cases where Alias initialization is always desired.
|
||||
// return-by-value version 2: returning a value of the alias type itself. We
|
||||
// move-construct an Alias instance (even if no the python-side inheritance is
|
||||
// involved). The is intended for cases where Alias initialization is always
|
||||
// desired.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
|
||||
static_assert(std::is_move_constructible<Alias<Class>>::value,
|
||||
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(result));
|
||||
static_assert(std::is_move_constructible<Alias<Class>>::value,
|
||||
"pybind11::init() return-by-alias-value factory function "
|
||||
"requires a movable alias class");
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(result));
|
||||
}
|
||||
|
||||
// Implementing class for py::init<...>()
|
||||
template <typename... Args>
|
||||
struct constructor {
|
||||
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
template <typename... Args> struct constructor {
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<!Class::has_alias, int> = 0>
|
||||
static void execute(Class &cl, const Extra &... extra) {
|
||||
cl.def(
|
||||
"__init__",
|
||||
[](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() =
|
||||
construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
||||
else
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
std::is_constructible<Cpp<Class>, Args...>::value,
|
||||
int> = 0>
|
||||
static void execute(Class &cl, const Extra &... extra) {
|
||||
cl.def(
|
||||
"__init__",
|
||||
[](value_and_holder &v_h, Args... args) {
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(
|
||||
std::forward<Args>(args)...);
|
||||
else
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
|
||||
std::forward<Args>(args)...);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
!std::is_constructible<Cpp<Class>, Args...>::value,
|
||||
int> = 0>
|
||||
static void execute(Class &cl, const Extra &... extra) {
|
||||
cl.def(
|
||||
"__init__",
|
||||
[](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
|
||||
std::forward<Args>(args)...);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Implementing class for py::init_alias<...>()
|
||||
template <typename... Args> struct alias_constructor {
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
std::is_constructible<Alias<Class>, Args...>::value,
|
||||
int> = 0>
|
||||
static void execute(Class &cl, const Extra &... extra) {
|
||||
cl.def(
|
||||
"__init__",
|
||||
[](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
|
||||
std::forward<Args>(args)...);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
|
||||
template <typename CFunc, typename AFunc = void_type (*)(),
|
||||
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
|
||||
typename = function_signature_t<CFunc>,
|
||||
typename = function_signature_t<AFunc>>
|
||||
struct factory;
|
||||
|
||||
// Specialization for py::init(Func)
|
||||
template <typename Func, typename Return, typename... Args>
|
||||
struct factory<Func, void_type (*)(), Return(Args...)> {
|
||||
remove_reference_t<Func> class_factory;
|
||||
remove_reference_t<Func> class_factory;
|
||||
|
||||
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
|
||||
factory(Func &&f) : class_factory(std::forward<Func>(f)) {}
|
||||
|
||||
// The given class either has no alias or has no separate alias factory;
|
||||
// this always constructs the class itself. If the class is registered with an alias
|
||||
// type and an alias instance is needed (i.e. because the final type is a Python class
|
||||
// inheriting from the C++ type) the returned value needs to either already be an alias
|
||||
// instance, or the alias needs to be constructible from a `Class &&` argument.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &...extra) && {
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__init__", [func = std::move(class_factory)]
|
||||
#else
|
||||
auto &func = class_factory;
|
||||
cl.def("__init__", [func]
|
||||
#endif
|
||||
// The given class either has no alias or has no separate alias factory;
|
||||
// this always constructs the class itself. If the class is registered with
|
||||
// an alias type and an alias instance is needed (i.e. because the final type
|
||||
// is a Python class inheriting from the C++ type) the returned value needs to
|
||||
// either already be an alias instance, or the alias needs to be constructible
|
||||
// from a `Class &&` argument.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &... extra) && {
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def(
|
||||
"__init__",
|
||||
[func = std::move(class_factory)]
|
||||
#else
|
||||
auto &func = class_factory;
|
||||
cl.def(
|
||||
"__init__",
|
||||
[func]
|
||||
#endif
|
||||
(value_and_holder &v_h, Args... args) {
|
||||
construct<Class>(v_h, func(std::forward<Args>(args)...),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
construct<Class>(v_h, func(std::forward<Args>(args)...),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization for py::init(Func, AliasFunc)
|
||||
template <typename CFunc, typename AFunc,
|
||||
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
|
||||
template <typename CFunc, typename AFunc, typename CReturn, typename... CArgs,
|
||||
typename AReturn, typename... AArgs>
|
||||
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
|
||||
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
static_assert(
|
||||
sizeof...(CArgs) == sizeof...(AArgs),
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
static_assert(
|
||||
all_of<std::is_same<CArgs, AArgs>...>::value,
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
|
||||
remove_reference_t<CFunc> class_factory;
|
||||
remove_reference_t<AFunc> alias_factory;
|
||||
remove_reference_t<CFunc> class_factory;
|
||||
remove_reference_t<AFunc> alias_factory;
|
||||
|
||||
factory(CFunc &&c, AFunc &&a)
|
||||
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
|
||||
factory(CFunc &&c, AFunc &&a)
|
||||
: class_factory(std::forward<CFunc>(c)),
|
||||
alias_factory(std::forward<AFunc>(a)) {}
|
||||
|
||||
// The class factory is called when the `self` type passed to `__init__` is the direct
|
||||
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra&... extra) && {
|
||||
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
|
||||
"only be used if the class has an alias");
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
|
||||
#else
|
||||
auto &class_func = class_factory;
|
||||
auto &alias_func = alias_factory;
|
||||
cl.def("__init__", [class_func, alias_func]
|
||||
#endif
|
||||
// The class factory is called when the `self` type passed to `__init__` is
|
||||
// the direct class (i.e. not inherited), the alias factory when `self` is a
|
||||
// Python-side subtype.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &... extra) && {
|
||||
static_assert(Class::has_alias,
|
||||
"The two-argument version of `py::init()` can "
|
||||
"only be used if the class has an alias");
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def(
|
||||
"__init__",
|
||||
[class_func = std::move(class_factory),
|
||||
alias_func = std::move(alias_factory)]
|
||||
#else
|
||||
auto &class_func = class_factory;
|
||||
auto &alias_func = alias_factory;
|
||||
cl.def(
|
||||
"__init__",
|
||||
[class_func, alias_func]
|
||||
#endif
|
||||
(value_and_holder &v_h, CArgs... args) {
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
// If the instance type equals the registered type we don't have inheritance, so
|
||||
// don't need the alias and can construct using the class function:
|
||||
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
|
||||
else
|
||||
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
// If the instance type equals the registered type we don't have
|
||||
// inheritance, so don't need the alias and can construct using the
|
||||
// class function:
|
||||
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...),
|
||||
false);
|
||||
else
|
||||
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...),
|
||||
true);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
/// Set just the C++ state. Same as `__init__`.
|
||||
template <typename Class, typename T>
|
||||
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
|
||||
construct<Class>(v_h, std::forward<T>(result), need_alias);
|
||||
construct<Class>(v_h, std::forward<T>(result), need_alias);
|
||||
}
|
||||
|
||||
/// Set both the C++ and Python states
|
||||
template <typename Class, typename T, typename O,
|
||||
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
|
||||
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
|
||||
construct<Class>(v_h, std::move(result.first), need_alias);
|
||||
setattr((PyObject *) v_h.inst, "__dict__", result.second);
|
||||
void setstate(value_and_holder &v_h, std::pair<T, O> &&result,
|
||||
bool need_alias) {
|
||||
construct<Class>(v_h, std::move(result.first), need_alias);
|
||||
setattr((PyObject *)v_h.inst, "__dict__", result.second);
|
||||
}
|
||||
|
||||
/// Implementation for py::pickle(GetState, SetState)
|
||||
template <typename Get, typename Set,
|
||||
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
|
||||
template <typename Get, typename Set, typename = function_signature_t<Get>,
|
||||
typename = function_signature_t<Set>>
|
||||
struct pickle_factory;
|
||||
|
||||
template <typename Get, typename Set,
|
||||
typename RetState, typename Self, typename NewInstance, typename ArgState>
|
||||
template <typename Get, typename Set, typename RetState, typename Self,
|
||||
typename NewInstance, typename ArgState>
|
||||
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
|
||||
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
|
||||
"The type returned by `__getstate__` must be the same "
|
||||
"as the argument accepted by `__setstate__`");
|
||||
static_assert(
|
||||
std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
|
||||
"The type returned by `__getstate__` must be the same "
|
||||
"as the argument accepted by `__setstate__`");
|
||||
|
||||
remove_reference_t<Get> get;
|
||||
remove_reference_t<Set> set;
|
||||
remove_reference_t<Get> get;
|
||||
remove_reference_t<Set> set;
|
||||
|
||||
pickle_factory(Get get, Set set)
|
||||
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
|
||||
pickle_factory(Get get, Set set)
|
||||
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) {}
|
||||
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &...extra) && {
|
||||
cl.def("__getstate__", std::move(get));
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &... extra) && {
|
||||
cl.def("__getstate__", std::move(get));
|
||||
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__setstate__", [func = std::move(set)]
|
||||
cl.def(
|
||||
"__setstate__",
|
||||
[func = std::move(set)]
|
||||
#else
|
||||
auto &func = set;
|
||||
cl.def("__setstate__", [func]
|
||||
auto &func = set;
|
||||
cl.def(
|
||||
"__setstate__",
|
||||
[func]
|
||||
#endif
|
||||
(value_and_holder &v_h, ArgState state) {
|
||||
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
},
|
||||
is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(initimpl)
|
||||
|
@@ -18,276 +18,323 @@ inline PyTypeObject *make_static_property_type();
|
||||
inline PyTypeObject *make_default_metaclass();
|
||||
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
|
||||
|
||||
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
|
||||
// Thread Specific Storage (TSS) API.
|
||||
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in
|
||||
// favor of the new Thread Specific Storage (TSS) API.
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
|
||||
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
|
||||
#define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
|
||||
#define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
|
||||
#define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
|
||||
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
|
||||
#else
|
||||
// Usually an int but a long on Cygwin64 with Python 3.x
|
||||
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
|
||||
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
|
||||
# if PY_MAJOR_VERSION < 3
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) \
|
||||
PyThread_delete_key_value(key)
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
do { \
|
||||
PyThread_delete_key_value((key)); \
|
||||
PyThread_set_key_value((key), (value)); \
|
||||
} while (false)
|
||||
# else
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) \
|
||||
PyThread_set_key_value((key), nullptr)
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
PyThread_set_key_value((key), (value))
|
||||
# endif
|
||||
// Usually an int but a long on Cygwin64 with Python 3.x
|
||||
#define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
|
||||
#define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_delete_key_value(key)
|
||||
#define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
do { \
|
||||
PyThread_delete_key_value((key)); \
|
||||
PyThread_set_key_value((key), (value)); \
|
||||
} while (false)
|
||||
#else
|
||||
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_set_key_value((key), nullptr)
|
||||
#define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
PyThread_set_key_value((key), (value))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
|
||||
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
|
||||
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
|
||||
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
|
||||
// which works. If not under a known-good stl, provide our own name-based hash and equality
|
||||
// functions that use the type name.
|
||||
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under
|
||||
// libc++ and possibly other STLs, this means `typeid(A)` from one module won't
|
||||
// equal `typeid(A)` from another module even when `A` is the same,
|
||||
// non-hidden-visibility type (e.g. from a common include). Under libstdc++,
|
||||
// this doesn't happen: equality and the type_index hash are based on the type
|
||||
// name, which works. If not under a known-good stl, provide our own name-based
|
||||
// hash and equality functions that use the type name.
|
||||
#if defined(__GLIBCXX__)
|
||||
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
|
||||
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
using type_hash = std::hash<std::type_index>;
|
||||
using type_equal_to = std::equal_to<std::type_index>;
|
||||
#else
|
||||
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
}
|
||||
|
||||
struct type_hash {
|
||||
size_t operator()(const std::type_index &t) const {
|
||||
size_t hash = 5381;
|
||||
const char *ptr = t.name();
|
||||
while (auto c = static_cast<unsigned char>(*ptr++))
|
||||
hash = (hash * 33) ^ c;
|
||||
return hash;
|
||||
}
|
||||
size_t operator()(const std::type_index &t) const {
|
||||
size_t hash = 5381;
|
||||
const char *ptr = t.name();
|
||||
while (auto c = static_cast<unsigned char>(*ptr++))
|
||||
hash = (hash * 33) ^ c;
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
struct type_equal_to {
|
||||
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
}
|
||||
bool operator()(const std::type_index &lhs,
|
||||
const std::type_index &rhs) const {
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename value_type>
|
||||
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
|
||||
using type_map =
|
||||
std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
|
||||
|
||||
struct overload_hash {
|
||||
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
|
||||
size_t value = std::hash<const void *>()(v.first);
|
||||
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
|
||||
return value;
|
||||
}
|
||||
inline size_t
|
||||
operator()(const std::pair<const PyObject *, const char *> &v) const {
|
||||
size_t value = std::hash<const void *>()(v.first);
|
||||
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value << 6) +
|
||||
(value >> 2);
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Internal data structure used to track registered instances and types.
|
||||
/// Whenever binary incompatible changes are made to this structure,
|
||||
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
|
||||
struct internals {
|
||||
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
|
||||
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
|
||||
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
|
||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
||||
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
||||
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
|
||||
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
|
||||
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
|
||||
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
|
||||
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
|
||||
PyTypeObject *static_property_type;
|
||||
PyTypeObject *default_metaclass;
|
||||
PyObject *instance_base;
|
||||
type_map<type_info *>
|
||||
registered_types_cpp; // std::type_index -> pybind11's type information
|
||||
std::unordered_map<PyTypeObject *, std::vector<type_info *>>
|
||||
registered_types_py; // PyTypeObject* -> base type_info(s)
|
||||
std::unordered_multimap<const void *, instance *>
|
||||
registered_instances; // void * -> instance*
|
||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash>
|
||||
inactive_overload_cache;
|
||||
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
||||
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
|
||||
std::forward_list<void (*)(std::exception_ptr)>
|
||||
registered_exception_translators;
|
||||
std::unordered_map<std::string, void *>
|
||||
shared_data; // Custom data to be shared across extensions
|
||||
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
|
||||
std::forward_list<std::string>
|
||||
static_strings; // Stores the std::strings backing detail::c_str()
|
||||
PyTypeObject *static_property_type;
|
||||
PyTypeObject *default_metaclass;
|
||||
PyObject *instance_base;
|
||||
#if defined(WITH_THREAD)
|
||||
PYBIND11_TLS_KEY_INIT(tstate);
|
||||
PyInterpreterState *istate = nullptr;
|
||||
PYBIND11_TLS_KEY_INIT(tstate);
|
||||
PyInterpreterState *istate = nullptr;
|
||||
#endif
|
||||
};
|
||||
|
||||
/// Additional type information which does not fit into the PyTypeObject.
|
||||
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
|
||||
struct type_info {
|
||||
PyTypeObject *type;
|
||||
const std::type_info *cpptype;
|
||||
size_t type_size, type_align, holder_size_in_ptrs;
|
||||
void *(*operator_new)(size_t);
|
||||
void (*init_instance)(instance *, const void *);
|
||||
void (*dealloc)(value_and_holder &v_h);
|
||||
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
||||
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
|
||||
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
|
||||
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
||||
void *get_buffer_data = nullptr;
|
||||
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
|
||||
/* A simple type never occurs as a (direct or indirect) parent
|
||||
* of a class that makes use of multiple inheritance */
|
||||
bool simple_type : 1;
|
||||
/* True if there is no multiple inheritance in this type's inheritance tree */
|
||||
bool simple_ancestors : 1;
|
||||
/* for base vs derived holder_type checks */
|
||||
bool default_holder : 1;
|
||||
/* true if this is a type registered with py::module_local */
|
||||
bool module_local : 1;
|
||||
PyTypeObject *type;
|
||||
const std::type_info *cpptype;
|
||||
size_t type_size, type_align, holder_size_in_ptrs;
|
||||
void *(*operator_new)(size_t);
|
||||
void (*init_instance)(instance *, const void *);
|
||||
void (*dealloc)(value_and_holder &v_h);
|
||||
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
||||
std::vector<std::pair<const std::type_info *, void *(*)(void *)>>
|
||||
implicit_casts;
|
||||
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
|
||||
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
||||
void *get_buffer_data = nullptr;
|
||||
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
|
||||
/* A simple type never occurs as a (direct or indirect) parent
|
||||
* of a class that makes use of multiple inheritance */
|
||||
bool simple_type : 1;
|
||||
/* True if there is no multiple inheritance in this type's inheritance tree */
|
||||
bool simple_ancestors : 1;
|
||||
/* for base vs derived holder_type checks */
|
||||
bool default_holder : 1;
|
||||
/* true if this is a type registered with py::module_local */
|
||||
bool module_local : 1;
|
||||
};
|
||||
|
||||
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
|
||||
/// Tracks the `internals` and `type_info` ABI version independent of the main
|
||||
/// library version
|
||||
#define PYBIND11_INTERNALS_VERSION 3
|
||||
|
||||
#if defined(_DEBUG)
|
||||
# define PYBIND11_BUILD_TYPE "_debug"
|
||||
#define PYBIND11_BUILD_TYPE "_debug"
|
||||
#else
|
||||
# define PYBIND11_BUILD_TYPE ""
|
||||
#define PYBIND11_BUILD_TYPE ""
|
||||
#endif
|
||||
|
||||
#if defined(WITH_THREAD)
|
||||
# define PYBIND11_INTERNALS_KIND ""
|
||||
#define PYBIND11_INTERNALS_KIND ""
|
||||
#else
|
||||
# define PYBIND11_INTERNALS_KIND "_without_thread"
|
||||
#define PYBIND11_INTERNALS_KIND "_without_thread"
|
||||
#endif
|
||||
|
||||
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
#define PYBIND11_INTERNALS_ID \
|
||||
"__pybind11_internals_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
|
||||
PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
#define PYBIND11_MODULE_LOCAL_ID \
|
||||
"__pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
|
||||
PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
/// Each module locally stores a pointer to the `internals` data. The data
|
||||
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
|
||||
inline internals **&get_internals_pp() {
|
||||
static internals **internals_pp = nullptr;
|
||||
return internals_pp;
|
||||
static internals **internals_pp = nullptr;
|
||||
return internals_pp;
|
||||
}
|
||||
|
||||
/// Return a reference to the current `internals` data
|
||||
PYBIND11_NOINLINE inline internals &get_internals() {
|
||||
auto **&internals_pp = get_internals_pp();
|
||||
if (internals_pp && *internals_pp)
|
||||
return **internals_pp;
|
||||
|
||||
constexpr auto *id = PYBIND11_INTERNALS_ID;
|
||||
auto builtins = handle(PyEval_GetBuiltins());
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
|
||||
internals_pp = static_cast<internals **>(capsule(builtins[id]));
|
||||
|
||||
// We loaded builtins through python's builtins, which means that our `error_already_set`
|
||||
// and `builtin_exception` may be different local classes than the ones set up in the
|
||||
// initial exception translator, below, so add another for our local exception classes.
|
||||
//
|
||||
// libstdc++ doesn't require this (types there are identified only by name)
|
||||
#if !defined(__GLIBCXX__)
|
||||
(*internals_pp)->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p) std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) { e.restore(); return;
|
||||
} catch (const builtin_exception &e) { e.set_error(); return;
|
||||
}
|
||||
}
|
||||
);
|
||||
#endif
|
||||
} else {
|
||||
if (!internals_pp) internals_pp = new internals*();
|
||||
auto *&internals_ptr = *internals_pp;
|
||||
internals_ptr = new internals();
|
||||
#if defined(WITH_THREAD)
|
||||
#if PY_VERSION_HEX < 0x03090000
|
||||
PyEval_InitThreads();
|
||||
#endif
|
||||
PyThreadState *tstate = PyThreadState_Get();
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
internals_ptr->tstate = PyThread_tss_alloc();
|
||||
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
|
||||
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
|
||||
PyThread_tss_set(internals_ptr->tstate, tstate);
|
||||
#else
|
||||
internals_ptr->tstate = PyThread_create_key();
|
||||
if (internals_ptr->tstate == -1)
|
||||
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
|
||||
PyThread_set_key_value(internals_ptr->tstate, tstate);
|
||||
#endif
|
||||
internals_ptr->istate = tstate->interp;
|
||||
#endif
|
||||
builtins[id] = capsule(internals_pp);
|
||||
internals_ptr->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p) std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) { e.restore(); return;
|
||||
} catch (const builtin_exception &e) { e.set_error(); return;
|
||||
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
|
||||
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
|
||||
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
|
||||
} catch (...) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
|
||||
return;
|
||||
}
|
||||
}
|
||||
);
|
||||
internals_ptr->static_property_type = make_static_property_type();
|
||||
internals_ptr->default_metaclass = make_default_metaclass();
|
||||
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
|
||||
}
|
||||
auto **&internals_pp = get_internals_pp();
|
||||
if (internals_pp && *internals_pp)
|
||||
return **internals_pp;
|
||||
|
||||
constexpr auto *id = PYBIND11_INTERNALS_ID;
|
||||
auto builtins = handle(PyEval_GetBuiltins());
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
|
||||
internals_pp = static_cast<internals **>(capsule(builtins[id]));
|
||||
|
||||
// We loaded builtins through python's builtins, which means that our
|
||||
// `error_already_set` and `builtin_exception` may be different local
|
||||
// classes than the ones set up in the initial exception translator, below,
|
||||
// so add another for our local exception classes.
|
||||
//
|
||||
// libstdc++ doesn't require this (types there are identified only by name)
|
||||
#if !defined(__GLIBCXX__)
|
||||
(*internals_pp)
|
||||
->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p)
|
||||
std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) {
|
||||
e.restore();
|
||||
return;
|
||||
} catch (const builtin_exception &e) {
|
||||
e.set_error();
|
||||
return;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
} else {
|
||||
if (!internals_pp)
|
||||
internals_pp = new internals *();
|
||||
auto *&internals_ptr = *internals_pp;
|
||||
internals_ptr = new internals();
|
||||
#if defined(WITH_THREAD)
|
||||
#if PY_VERSION_HEX < 0x03090000
|
||||
PyEval_InitThreads();
|
||||
#endif
|
||||
PyThreadState *tstate = PyThreadState_Get();
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
internals_ptr->tstate = PyThread_tss_alloc();
|
||||
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
|
||||
pybind11_fail(
|
||||
"get_internals: could not successfully initialize the TSS key!");
|
||||
PyThread_tss_set(internals_ptr->tstate, tstate);
|
||||
#else
|
||||
internals_ptr->tstate = PyThread_create_key();
|
||||
if (internals_ptr->tstate == -1)
|
||||
pybind11_fail(
|
||||
"get_internals: could not successfully initialize the TLS key!");
|
||||
PyThread_set_key_value(internals_ptr->tstate, tstate);
|
||||
#endif
|
||||
internals_ptr->istate = tstate->interp;
|
||||
#endif
|
||||
builtins[id] = capsule(internals_pp);
|
||||
internals_ptr->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p)
|
||||
std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) {
|
||||
e.restore();
|
||||
return;
|
||||
} catch (const builtin_exception &e) {
|
||||
e.set_error();
|
||||
return;
|
||||
} catch (const std::bad_alloc &e) {
|
||||
PyErr_SetString(PyExc_MemoryError, e.what());
|
||||
return;
|
||||
} catch (const std::domain_error &e) {
|
||||
PyErr_SetString(PyExc_ValueError, e.what());
|
||||
return;
|
||||
} catch (const std::invalid_argument &e) {
|
||||
PyErr_SetString(PyExc_ValueError, e.what());
|
||||
return;
|
||||
} catch (const std::length_error &e) {
|
||||
PyErr_SetString(PyExc_ValueError, e.what());
|
||||
return;
|
||||
} catch (const std::out_of_range &e) {
|
||||
PyErr_SetString(PyExc_IndexError, e.what());
|
||||
return;
|
||||
} catch (const std::range_error &e) {
|
||||
PyErr_SetString(PyExc_ValueError, e.what());
|
||||
return;
|
||||
} catch (const std::exception &e) {
|
||||
PyErr_SetString(PyExc_RuntimeError, e.what());
|
||||
return;
|
||||
} catch (...) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
|
||||
return;
|
||||
}
|
||||
});
|
||||
internals_ptr->static_property_type = make_static_property_type();
|
||||
internals_ptr->default_metaclass = make_default_metaclass();
|
||||
internals_ptr->instance_base =
|
||||
make_object_base_type(internals_ptr->default_metaclass);
|
||||
}
|
||||
return **internals_pp;
|
||||
}
|
||||
|
||||
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
|
||||
/// Works like `internals.registered_types_cpp`, but for module-local registered
|
||||
/// types:
|
||||
inline type_map<type_info *> ®istered_local_types_cpp() {
|
||||
static type_map<type_info *> locals{};
|
||||
return locals;
|
||||
static type_map<type_info *> locals{};
|
||||
return locals;
|
||||
}
|
||||
|
||||
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
|
||||
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
|
||||
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
|
||||
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
|
||||
template <typename... Args>
|
||||
const char *c_str(Args &&...args) {
|
||||
auto &strings = get_internals().static_strings;
|
||||
strings.emplace_front(std::forward<Args>(args)...);
|
||||
return strings.front().c_str();
|
||||
/// Constructs a std::string with the given arguments, stores it in `internals`,
|
||||
/// and returns its `c_str()`. Such strings objects have a long storage
|
||||
/// duration -- the internal strings are only cleared when the program exits or
|
||||
/// after interpreter shutdown (when embedding), and so are suitable for c-style
|
||||
/// strings needed by Python internals (such as PyTypeObject's tp_name).
|
||||
template <typename... Args> const char *c_str(Args &&... args) {
|
||||
auto &strings = get_internals().static_strings;
|
||||
strings.emplace_front(std::forward<Args>(args)...);
|
||||
return strings.front().c_str();
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Returns a named pointer that is shared among all extension modules (using the same
|
||||
/// pybind11 version) running in the current interpreter. Names starting with underscores
|
||||
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
|
||||
/// Returns a named pointer that is shared among all extension modules (using
|
||||
/// the same pybind11 version) running in the current interpreter. Names
|
||||
/// starting with underscores are reserved for internal usage. Returns `nullptr`
|
||||
/// if no matching entry was found.
|
||||
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
return it != internals.shared_data.end() ? it->second : nullptr;
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
return it != internals.shared_data.end() ? it->second : nullptr;
|
||||
}
|
||||
|
||||
/// Set the shared data that can be later recovered by `get_shared_data()`.
|
||||
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
|
||||
detail::get_internals().shared_data[name] = data;
|
||||
return data;
|
||||
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name,
|
||||
void *data) {
|
||||
detail::get_internals().shared_data[name] = data;
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
|
||||
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
|
||||
/// added to the shared data under the given name and a reference to it is returned.
|
||||
template<typename T>
|
||||
T &get_or_create_shared_data(const std::string &name) {
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
|
||||
if (!ptr) {
|
||||
ptr = new T();
|
||||
internals.shared_data[name] = ptr;
|
||||
}
|
||||
return *ptr;
|
||||
/// Returns a typed reference to a shared data entry (by using
|
||||
/// `get_shared_data()`) if such entry exists. Otherwise, a new object of
|
||||
/// default-constructible type `T` is added to the shared data under the given
|
||||
/// name and a reference to it is returned.
|
||||
template <typename T> T &get_or_create_shared_data(const std::string &name) {
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
T *ptr = (T *)(it != internals.shared_data.end() ? it->second : nullptr);
|
||||
if (!ptr) {
|
||||
ptr = new T();
|
||||
internals.shared_data[name] = ptr;
|
||||
}
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -22,34 +22,35 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
/// Erase all occurrences of a substring
|
||||
inline void erase_all(std::string &string, const std::string &search) {
|
||||
for (size_t pos = 0;;) {
|
||||
pos = string.find(search, pos);
|
||||
if (pos == std::string::npos) break;
|
||||
string.erase(pos, search.length());
|
||||
}
|
||||
for (size_t pos = 0;;) {
|
||||
pos = string.find(search, pos);
|
||||
if (pos == std::string::npos)
|
||||
break;
|
||||
string.erase(pos, search.length());
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
|
||||
#if defined(__GNUG__)
|
||||
int status = 0;
|
||||
std::unique_ptr<char, void (*)(void *)> res {
|
||||
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
|
||||
if (status == 0)
|
||||
name = res.get();
|
||||
int status = 0;
|
||||
std::unique_ptr<char, void (*)(void *)> res{
|
||||
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free};
|
||||
if (status == 0)
|
||||
name = res.get();
|
||||
#else
|
||||
detail::erase_all(name, "class ");
|
||||
detail::erase_all(name, "struct ");
|
||||
detail::erase_all(name, "enum ");
|
||||
detail::erase_all(name, "class ");
|
||||
detail::erase_all(name, "struct ");
|
||||
detail::erase_all(name, "enum ");
|
||||
#endif
|
||||
detail::erase_all(name, "pybind11::");
|
||||
detail::erase_all(name, "pybind11::");
|
||||
}
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Return a string representation of a C++ type
|
||||
template <typename T> static std::string type_id() {
|
||||
std::string name(typeid(T).name());
|
||||
detail::clean_type_id(name);
|
||||
return name;
|
||||
std::string name(typeid(T).name());
|
||||
detail::clean_type_id(name);
|
||||
return name;
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -9,23 +9,23 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include "eval.h"
|
||||
#include "pybind11.h"
|
||||
|
||||
#if defined(PYPY_VERSION)
|
||||
# error Embedding the interpreter is not supported with PyPy
|
||||
#error Embedding the interpreter is not supported with PyPy
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" PyObject *pybind11_init_impl_##name() { \
|
||||
return pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" PyObject *pybind11_init_impl_##name() { \
|
||||
return pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#else
|
||||
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" void pybind11_init_impl_##name() { \
|
||||
pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" void pybind11_init_impl_##name() { \
|
||||
pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#endif
|
||||
|
||||
/** \rst
|
||||
@@ -43,75 +43,78 @@
|
||||
});
|
||||
}
|
||||
\endrst */
|
||||
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
|
||||
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
|
||||
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
|
||||
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
|
||||
try { \
|
||||
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
||||
return m.ptr(); \
|
||||
} catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
} \
|
||||
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
|
||||
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
|
||||
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
|
||||
|
||||
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
|
||||
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
|
||||
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
|
||||
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
|
||||
try { \
|
||||
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
||||
return m.ptr(); \
|
||||
} catch (pybind11::error_already_set & e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
} \
|
||||
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
pybind11::detail::embedded_module name( \
|
||||
PYBIND11_TOSTRING(name), PYBIND11_CONCAT(pybind11_init_impl_, name)); \
|
||||
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module & variable)
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
|
||||
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error
|
||||
/// checks.
|
||||
struct embedded_module {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
using init_t = PyObject *(*)();
|
||||
using init_t = PyObject *(*)();
|
||||
#else
|
||||
using init_t = void (*)();
|
||||
using init_t = void (*)();
|
||||
#endif
|
||||
embedded_module(const char *name, init_t init) {
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail("Can't add new modules after the interpreter has been initialized");
|
||||
embedded_module(const char *name, init_t init) {
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail(
|
||||
"Can't add new modules after the interpreter has been initialized");
|
||||
|
||||
auto result = PyImport_AppendInittab(name, init);
|
||||
if (result == -1)
|
||||
pybind11_fail("Insufficient memory to add a new module");
|
||||
}
|
||||
auto result = PyImport_AppendInittab(name, init);
|
||||
if (result == -1)
|
||||
pybind11_fail("Insufficient memory to add a new module");
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/** \rst
|
||||
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
|
||||
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
|
||||
optional parameter can be used to skip the registration of signal handlers (see the
|
||||
`Python documentation`_ for details). Calling this function again after the interpreter
|
||||
has already been initialized is a fatal error.
|
||||
Initialize the Python interpreter. No other pybind11 or CPython API
|
||||
functions can be called before this is done; with the exception of
|
||||
`PYBIND11_EMBEDDED_MODULE`. The optional parameter can be used to skip the
|
||||
registration of signal handlers (see the `Python documentation`_ for details).
|
||||
Calling this function again after the interpreter has already been initialized
|
||||
is a fatal error.
|
||||
|
||||
If initializing the Python interpreter fails, then the program is terminated. (This
|
||||
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
|
||||
of throwing exceptions on errors.)
|
||||
If initializing the Python interpreter fails, then the program is
|
||||
terminated. (This is controlled by the CPython runtime and is an exception to
|
||||
pybind11's normal behavior of throwing exceptions on errors.)
|
||||
|
||||
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
|
||||
\endrst */
|
||||
.. _Python documentation:
|
||||
https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx \endrst */
|
||||
inline void initialize_interpreter(bool init_signal_handlers = true) {
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail("The interpreter is already running");
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail("The interpreter is already running");
|
||||
|
||||
Py_InitializeEx(init_signal_handlers ? 1 : 0);
|
||||
Py_InitializeEx(init_signal_handlers ? 1 : 0);
|
||||
|
||||
// Make .py files in the working directory available by default
|
||||
module::import("sys").attr("path").cast<list>().append(".");
|
||||
// Make .py files in the working directory available by default
|
||||
module::import("sys").attr("path").cast<list>().append(".");
|
||||
}
|
||||
|
||||
/** \rst
|
||||
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
|
||||
after this. In addition, pybind11 objects must not outlive the interpreter:
|
||||
Shut down the Python interpreter. No pybind11 or CPython API functions can
|
||||
be called after this. In addition, pybind11 objects must not outlive the
|
||||
interpreter:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
@@ -136,32 +139,33 @@ inline void initialize_interpreter(bool init_signal_handlers = true) {
|
||||
|
||||
.. warning::
|
||||
|
||||
The interpreter can be restarted by calling `initialize_interpreter` again.
|
||||
Modules created using pybind11 can be safely re-initialized. However, Python
|
||||
itself cannot completely unload binary extension modules and there are several
|
||||
caveats with regard to interpreter restarting. All the details can be found
|
||||
in the CPython documentation. In short, not all interpreter memory may be
|
||||
The interpreter can be restarted by calling `initialize_interpreter`
|
||||
again. Modules created using pybind11 can be safely re-initialized. However,
|
||||
Python itself cannot completely unload binary extension modules and there are
|
||||
several caveats with regard to interpreter restarting. All the details can be
|
||||
found in the CPython documentation. In short, not all interpreter memory may be
|
||||
freed, either due to reference cycles or user-created global data.
|
||||
|
||||
\endrst */
|
||||
inline void finalize_interpreter() {
|
||||
handle builtins(PyEval_GetBuiltins());
|
||||
const char *id = PYBIND11_INTERNALS_ID;
|
||||
handle builtins(PyEval_GetBuiltins());
|
||||
const char *id = PYBIND11_INTERNALS_ID;
|
||||
|
||||
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
|
||||
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
|
||||
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
|
||||
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
|
||||
// It could also be stashed in builtins, so look there too:
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
|
||||
internals_ptr_ptr = capsule(builtins[id]);
|
||||
// Get the internals pointer (without creating it if it doesn't exist). It's
|
||||
// possible for the internals to be created during Py_Finalize() (e.g. if a
|
||||
// py::capsule calls `get_internals()` during destruction), so we get the
|
||||
// pointer-pointer here and check it after Py_Finalize().
|
||||
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
|
||||
// It could also be stashed in builtins, so look there too:
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
|
||||
internals_ptr_ptr = capsule(builtins[id]);
|
||||
|
||||
Py_Finalize();
|
||||
Py_Finalize();
|
||||
|
||||
if (internals_ptr_ptr) {
|
||||
delete *internals_ptr_ptr;
|
||||
*internals_ptr_ptr = nullptr;
|
||||
}
|
||||
if (internals_ptr_ptr) {
|
||||
delete *internals_ptr_ptr;
|
||||
*internals_ptr_ptr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
/** \rst
|
||||
@@ -179,22 +183,24 @@ inline void finalize_interpreter() {
|
||||
\endrst */
|
||||
class scoped_interpreter {
|
||||
public:
|
||||
scoped_interpreter(bool init_signal_handlers = true) {
|
||||
initialize_interpreter(init_signal_handlers);
|
||||
}
|
||||
scoped_interpreter(bool init_signal_handlers = true) {
|
||||
initialize_interpreter(init_signal_handlers);
|
||||
}
|
||||
|
||||
scoped_interpreter(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
|
||||
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
|
||||
scoped_interpreter(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter(scoped_interpreter &&other) noexcept {
|
||||
other.is_valid = false;
|
||||
}
|
||||
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
|
||||
|
||||
~scoped_interpreter() {
|
||||
if (is_valid)
|
||||
finalize_interpreter();
|
||||
}
|
||||
~scoped_interpreter() {
|
||||
if (is_valid)
|
||||
finalize_interpreter();
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_valid = true;
|
||||
bool is_valid = true;
|
||||
};
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -2,8 +2,8 @@
|
||||
pybind11/exec.h: Support for evaluating Python expressions and statements
|
||||
from strings and files
|
||||
|
||||
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
|
||||
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de>
|
||||
and Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
@@ -16,102 +16,119 @@
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
enum eval_mode {
|
||||
/// Evaluate a string containing an isolated expression
|
||||
eval_expr,
|
||||
/// Evaluate a string containing an isolated expression
|
||||
eval_expr,
|
||||
|
||||
/// Evaluate a string containing a single statement. Returns \c none
|
||||
eval_single_statement,
|
||||
/// Evaluate a string containing a single statement. Returns \c none
|
||||
eval_single_statement,
|
||||
|
||||
/// Evaluate a string containing a sequence of statement. Returns \c none
|
||||
eval_statements
|
||||
/// Evaluate a string containing a sequence of statement. Returns \c none
|
||||
eval_statements
|
||||
};
|
||||
|
||||
template <eval_mode mode = eval_expr>
|
||||
object eval(str expr, object global = globals(), object local = object()) {
|
||||
if (!local)
|
||||
local = global;
|
||||
if (!local)
|
||||
local = global;
|
||||
|
||||
/* PyRun_String does not accept a PyObject / encoding specifier,
|
||||
this seems to be the only alternative */
|
||||
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
|
||||
/* PyRun_String does not accept a PyObject / encoding specifier,
|
||||
this seems to be the only alternative */
|
||||
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string)expr;
|
||||
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr: start = Py_eval_input; break;
|
||||
case eval_single_statement: start = Py_single_input; break;
|
||||
case eval_statements: start = Py_file_input; break;
|
||||
default: pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr:
|
||||
start = Py_eval_input;
|
||||
break;
|
||||
case eval_single_statement:
|
||||
start = Py_single_input;
|
||||
break;
|
||||
case eval_statements:
|
||||
start = Py_file_input;
|
||||
break;
|
||||
default:
|
||||
pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
|
||||
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
PyObject *result =
|
||||
PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
}
|
||||
|
||||
template <eval_mode mode = eval_expr, size_t N>
|
||||
object eval(const char (&s)[N], object global = globals(), object local = object()) {
|
||||
/* Support raw string literals by removing common leading whitespace */
|
||||
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
|
||||
: str(s);
|
||||
return eval<mode>(expr, global, local);
|
||||
object eval(const char (&s)[N], object global = globals(),
|
||||
object local = object()) {
|
||||
/* Support raw string literals by removing common leading whitespace */
|
||||
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
|
||||
: str(s);
|
||||
return eval<mode>(expr, global, local);
|
||||
}
|
||||
|
||||
inline void exec(str expr, object global = globals(), object local = object()) {
|
||||
eval<eval_statements>(expr, global, local);
|
||||
eval<eval_statements>(expr, global, local);
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
void exec(const char (&s)[N], object global = globals(), object local = object()) {
|
||||
eval<eval_statements>(s, global, local);
|
||||
void exec(const char (&s)[N], object global = globals(),
|
||||
object local = object()) {
|
||||
eval<eval_statements>(s, global, local);
|
||||
}
|
||||
|
||||
template <eval_mode mode = eval_statements>
|
||||
object eval_file(str fname, object global = globals(), object local = object()) {
|
||||
if (!local)
|
||||
local = global;
|
||||
object eval_file(str fname, object global = globals(),
|
||||
object local = object()) {
|
||||
if (!local)
|
||||
local = global;
|
||||
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr: start = Py_eval_input; break;
|
||||
case eval_single_statement: start = Py_single_input; break;
|
||||
case eval_statements: start = Py_file_input; break;
|
||||
default: pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr:
|
||||
start = Py_eval_input;
|
||||
break;
|
||||
case eval_single_statement:
|
||||
start = Py_single_input;
|
||||
break;
|
||||
case eval_statements:
|
||||
start = Py_file_input;
|
||||
break;
|
||||
default:
|
||||
pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
|
||||
int closeFile = 1;
|
||||
std::string fname_str = (std::string) fname;
|
||||
int closeFile = 1;
|
||||
std::string fname_str = (std::string)fname;
|
||||
#if PY_VERSION_HEX >= 0x03040000
|
||||
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
|
||||
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
|
||||
#elif PY_VERSION_HEX >= 0x03000000
|
||||
FILE *f = _Py_fopen(fname.ptr(), "r");
|
||||
FILE *f = _Py_fopen(fname.ptr(), "r");
|
||||
#else
|
||||
/* No unicode support in open() :( */
|
||||
auto fobj = reinterpret_steal<object>(PyFile_FromString(
|
||||
const_cast<char *>(fname_str.c_str()),
|
||||
const_cast<char*>("r")));
|
||||
FILE *f = nullptr;
|
||||
if (fobj)
|
||||
f = PyFile_AsFile(fobj.ptr());
|
||||
closeFile = 0;
|
||||
/* No unicode support in open() :( */
|
||||
auto fobj = reinterpret_steal<object>(PyFile_FromString(
|
||||
const_cast<char *>(fname_str.c_str()), const_cast<char *>("r")));
|
||||
FILE *f = nullptr;
|
||||
if (fobj)
|
||||
f = PyFile_AsFile(fobj.ptr());
|
||||
closeFile = 0;
|
||||
#endif
|
||||
if (!f) {
|
||||
PyErr_Clear();
|
||||
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
|
||||
}
|
||||
if (!f) {
|
||||
PyErr_Clear();
|
||||
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
|
||||
}
|
||||
|
||||
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
|
||||
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
|
||||
local.ptr());
|
||||
(void) closeFile;
|
||||
PyObject *result =
|
||||
PyRun_File(f, fname_str.c_str(), start, global.ptr(), local.ptr());
|
||||
(void)closeFile;
|
||||
#else
|
||||
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
|
||||
local.ptr(), closeFile);
|
||||
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
|
||||
local.ptr(), closeFile);
|
||||
#endif
|
||||
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -17,91 +17,99 @@ NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename Return, typename... Args>
|
||||
struct type_caster<std::function<Return(Args...)>> {
|
||||
using type = std::function<Return(Args...)>;
|
||||
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
||||
using function_type = Return (*) (Args...);
|
||||
using type = std::function<Return(Args...)>;
|
||||
using retval_type =
|
||||
conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
||||
using function_type = Return (*)(Args...);
|
||||
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (src.is_none()) {
|
||||
// Defer accepting None to other overloads (if we aren't in convert mode):
|
||||
if (!convert) return false;
|
||||
return true;
|
||||
}
|
||||
bool load(handle src, bool convert) {
|
||||
if (src.is_none()) {
|
||||
// Defer accepting None to other overloads (if we aren't in convert mode):
|
||||
if (!convert)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!isinstance<function>(src))
|
||||
return false;
|
||||
if (!isinstance<function>(src))
|
||||
return false;
|
||||
|
||||
auto func = reinterpret_borrow<function>(src);
|
||||
auto func = reinterpret_borrow<function>(src);
|
||||
|
||||
/*
|
||||
When passing a C++ function as an argument to another C++
|
||||
function via Python, every function call would normally involve
|
||||
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
|
||||
Here, we try to at least detect the case where the function is
|
||||
stateless (i.e. function pointer or lambda function without
|
||||
captured variables), in which case the roundtrip can be avoided.
|
||||
*/
|
||||
if (auto cfunc = func.cpp_function()) {
|
||||
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
|
||||
auto rec = (function_record *) c;
|
||||
/*
|
||||
When passing a C++ function as an argument to another C++
|
||||
function via Python, every function call would normally involve
|
||||
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
|
||||
Here, we try to at least detect the case where the function is
|
||||
stateless (i.e. function pointer or lambda function without
|
||||
captured variables), in which case the roundtrip can be avoided.
|
||||
*/
|
||||
if (auto cfunc = func.cpp_function()) {
|
||||
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
|
||||
auto rec = (function_record *)c;
|
||||
|
||||
if (rec && rec->is_stateless &&
|
||||
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
||||
struct capture { function_type f; };
|
||||
value = ((capture *) &rec->data)->f;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// ensure GIL is held during functor destruction
|
||||
struct func_handle {
|
||||
function f;
|
||||
func_handle(function&& f_) : f(std::move(f_)) {}
|
||||
func_handle(const func_handle&) = default;
|
||||
~func_handle() {
|
||||
gil_scoped_acquire acq;
|
||||
function kill_f(std::move(f));
|
||||
}
|
||||
if (rec && rec->is_stateless &&
|
||||
same_type(typeid(function_type),
|
||||
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
||||
struct capture {
|
||||
function_type f;
|
||||
};
|
||||
|
||||
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
|
||||
// gil_scoped_acquire acq;
|
||||
// object retval(hfunc.f(std::forward<Args>(args)...));
|
||||
// /* Visual studio 2015 parser issue: need parentheses around this expression */
|
||||
// return (retval.template cast<Return>());
|
||||
// };
|
||||
|
||||
struct func_wrapper {
|
||||
func_handle hfunc;
|
||||
func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {}
|
||||
Return operator()(Args... args) const {
|
||||
gil_scoped_acquire acq;
|
||||
object retval(hfunc.f(std::forward<Args>(args)...));
|
||||
/* Visual studio 2015 parser issue: need parentheses around this expression */
|
||||
return (retval.template cast<Return>());
|
||||
}
|
||||
};
|
||||
|
||||
value = func_wrapper(func_handle(std::move(func)));
|
||||
|
||||
value = ((capture *)&rec->data)->f;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
|
||||
if (!f_)
|
||||
return none().inc_ref();
|
||||
// ensure GIL is held during functor destruction
|
||||
struct func_handle {
|
||||
function f;
|
||||
func_handle(function &&f_) : f(std::move(f_)) {}
|
||||
func_handle(const func_handle &) = default;
|
||||
~func_handle() {
|
||||
gil_scoped_acquire acq;
|
||||
function kill_f(std::move(f));
|
||||
}
|
||||
};
|
||||
|
||||
auto result = f_.template target<function_type>();
|
||||
if (result)
|
||||
return cpp_function(*result, policy).release();
|
||||
else
|
||||
return cpp_function(std::forward<Func>(f_), policy).release();
|
||||
}
|
||||
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
|
||||
// gil_scoped_acquire acq;
|
||||
// object retval(hfunc.f(std::forward<Args>(args)...));
|
||||
// /* Visual studio 2015 parser issue: need parentheses around this
|
||||
// expression */ return (retval.template cast<Return>());
|
||||
// };
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
|
||||
+ make_caster<retval_type>::name + _("]"));
|
||||
struct func_wrapper {
|
||||
func_handle hfunc;
|
||||
func_wrapper(func_handle &&hf) : hfunc(std::move(hf)) {}
|
||||
Return operator()(Args... args) const {
|
||||
gil_scoped_acquire acq;
|
||||
object retval(hfunc.f(std::forward<Args>(args)...));
|
||||
/* Visual studio 2015 parser issue: need parentheses around this
|
||||
* expression */
|
||||
return (retval.template cast<Return>());
|
||||
}
|
||||
};
|
||||
|
||||
value = func_wrapper(func_handle(std::move(func)));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static handle cast(Func &&f_, return_value_policy policy,
|
||||
handle /* parent */) {
|
||||
if (!f_)
|
||||
return none().inc_ref();
|
||||
|
||||
auto result = f_.template target<function_type>();
|
||||
if (result)
|
||||
return cpp_function(*result, policy).release();
|
||||
else
|
||||
return cpp_function(std::forward<Func>(f_), policy).release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Callable[[") +
|
||||
concat(make_caster<Args>::name...) + _("], ") +
|
||||
make_caster<retval_type>::name + _("]"));
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
/*
|
||||
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
|
||||
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to
|
||||
Python
|
||||
|
||||
Copyright (c) 2017 Henry F. Schreiner
|
||||
|
||||
@@ -11,11 +12,11 @@
|
||||
|
||||
#include "pybind11.h"
|
||||
|
||||
#include <streambuf>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <streambuf>
|
||||
#include <string>
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
@@ -23,56 +24,50 @@ NAMESPACE_BEGIN(detail)
|
||||
// Buffer that writes to Python instead of C++
|
||||
class pythonbuf : public std::streambuf {
|
||||
private:
|
||||
using traits_type = std::streambuf::traits_type;
|
||||
using traits_type = std::streambuf::traits_type;
|
||||
|
||||
const size_t buf_size;
|
||||
std::unique_ptr<char[]> d_buffer;
|
||||
object pywrite;
|
||||
object pyflush;
|
||||
const size_t buf_size;
|
||||
std::unique_ptr<char[]> d_buffer;
|
||||
object pywrite;
|
||||
object pyflush;
|
||||
|
||||
int overflow(int c) {
|
||||
if (!traits_type::eq_int_type(c, traits_type::eof())) {
|
||||
*pptr() = traits_type::to_char_type(c);
|
||||
pbump(1);
|
||||
}
|
||||
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
|
||||
int overflow(int c) {
|
||||
if (!traits_type::eq_int_type(c, traits_type::eof())) {
|
||||
*pptr() = traits_type::to_char_type(c);
|
||||
pbump(1);
|
||||
}
|
||||
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
|
||||
}
|
||||
|
||||
int sync() {
|
||||
if (pbase() != pptr()) {
|
||||
// This subtraction cannot be negative, so dropping the sign
|
||||
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
|
||||
int sync() {
|
||||
if (pbase() != pptr()) {
|
||||
// This subtraction cannot be negative, so dropping the sign
|
||||
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
|
||||
|
||||
{
|
||||
gil_scoped_acquire tmp;
|
||||
pywrite(line);
|
||||
pyflush();
|
||||
}
|
||||
{
|
||||
gil_scoped_acquire tmp;
|
||||
pywrite(line);
|
||||
pyflush();
|
||||
}
|
||||
|
||||
setp(pbase(), epptr());
|
||||
}
|
||||
return 0;
|
||||
setp(pbase(), epptr());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public:
|
||||
pythonbuf(object pyostream, size_t buffer_size = 1024)
|
||||
: buf_size(buffer_size), d_buffer(new char[buf_size]),
|
||||
pywrite(pyostream.attr("write")), pyflush(pyostream.attr("flush")) {
|
||||
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
|
||||
}
|
||||
|
||||
pythonbuf(object pyostream, size_t buffer_size = 1024)
|
||||
: buf_size(buffer_size),
|
||||
d_buffer(new char[buf_size]),
|
||||
pywrite(pyostream.attr("write")),
|
||||
pyflush(pyostream.attr("flush")) {
|
||||
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
|
||||
}
|
||||
|
||||
/// Sync before destroy
|
||||
~pythonbuf() {
|
||||
sync();
|
||||
}
|
||||
/// Sync before destroy
|
||||
~pythonbuf() { sync(); }
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
|
||||
/** \rst
|
||||
This a move-only guard that redirects output.
|
||||
|
||||
@@ -93,35 +88,32 @@ NAMESPACE_END(detail)
|
||||
.. code-block:: cpp
|
||||
|
||||
{
|
||||
py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
|
||||
std::cerr << "Hello, World!";
|
||||
py::scoped_ostream_redirect output{std::cerr,
|
||||
py::module::import("sys").attr("stderr")}; std::cerr << "Hello, World!";
|
||||
}
|
||||
\endrst */
|
||||
class scoped_ostream_redirect {
|
||||
protected:
|
||||
std::streambuf *old;
|
||||
std::ostream &costream;
|
||||
detail::pythonbuf buffer;
|
||||
std::streambuf *old;
|
||||
std::ostream &costream;
|
||||
detail::pythonbuf buffer;
|
||||
|
||||
public:
|
||||
scoped_ostream_redirect(
|
||||
std::ostream &costream = std::cout,
|
||||
object pyostream = module::import("sys").attr("stdout"))
|
||||
: costream(costream), buffer(pyostream) {
|
||||
old = costream.rdbuf(&buffer);
|
||||
}
|
||||
scoped_ostream_redirect(
|
||||
std::ostream &costream = std::cout,
|
||||
object pyostream = module::import("sys").attr("stdout"))
|
||||
: costream(costream), buffer(pyostream) {
|
||||
old = costream.rdbuf(&buffer);
|
||||
}
|
||||
|
||||
~scoped_ostream_redirect() {
|
||||
costream.rdbuf(old);
|
||||
}
|
||||
~scoped_ostream_redirect() { costream.rdbuf(old); }
|
||||
|
||||
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
|
||||
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
|
||||
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
|
||||
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
|
||||
};
|
||||
|
||||
|
||||
/** \rst
|
||||
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
|
||||
is provided primary to make ``py::call_guard`` easier to make.
|
||||
@@ -135,44 +127,44 @@ public:
|
||||
\endrst */
|
||||
class scoped_estream_redirect : public scoped_ostream_redirect {
|
||||
public:
|
||||
scoped_estream_redirect(
|
||||
std::ostream &costream = std::cerr,
|
||||
object pyostream = module::import("sys").attr("stderr"))
|
||||
: scoped_ostream_redirect(costream,pyostream) {}
|
||||
scoped_estream_redirect(
|
||||
std::ostream &costream = std::cerr,
|
||||
object pyostream = module::import("sys").attr("stderr"))
|
||||
: scoped_ostream_redirect(costream, pyostream) {}
|
||||
};
|
||||
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
// Class to redirect output as a context manager. C++ backend.
|
||||
class OstreamRedirect {
|
||||
bool do_stdout_;
|
||||
bool do_stderr_;
|
||||
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
|
||||
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
|
||||
bool do_stdout_;
|
||||
bool do_stderr_;
|
||||
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
|
||||
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
|
||||
|
||||
public:
|
||||
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
|
||||
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
|
||||
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
|
||||
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
|
||||
|
||||
void enter() {
|
||||
if (do_stdout_)
|
||||
redirect_stdout.reset(new scoped_ostream_redirect());
|
||||
if (do_stderr_)
|
||||
redirect_stderr.reset(new scoped_estream_redirect());
|
||||
}
|
||||
void enter() {
|
||||
if (do_stdout_)
|
||||
redirect_stdout.reset(new scoped_ostream_redirect());
|
||||
if (do_stderr_)
|
||||
redirect_stderr.reset(new scoped_estream_redirect());
|
||||
}
|
||||
|
||||
void exit() {
|
||||
redirect_stdout.reset();
|
||||
redirect_stderr.reset();
|
||||
}
|
||||
void exit() {
|
||||
redirect_stdout.reset();
|
||||
redirect_stderr.reset();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/** \rst
|
||||
This is a helper function to add a C++ redirect context manager to Python
|
||||
instead of using a C++ guard. To use it, add the following to your binding code:
|
||||
instead of using a C++ guard. To use it, add the following to your binding
|
||||
code:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
@@ -197,11 +189,13 @@ NAMESPACE_END(detail)
|
||||
m.noisy_function_with_error_printing()
|
||||
|
||||
\endrst */
|
||||
inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
|
||||
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
|
||||
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
|
||||
.def("__enter__", &detail::OstreamRedirect::enter)
|
||||
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
|
||||
inline class_<detail::OstreamRedirect>
|
||||
add_ostream_redirect(module m, std::string name = "ostream_redirect") {
|
||||
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
|
||||
.def(init<bool, bool>(), arg("stdout") = true, arg("stderr") = true)
|
||||
.def("__enter__", &detail::OstreamRedirect::enter)
|
||||
.def("__exit__",
|
||||
[](detail::OstreamRedirect &self_, args) { self_.exit(); });
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -12,10 +12,13 @@
|
||||
#include "pybind11.h"
|
||||
|
||||
#if defined(__clang__) && !defined(__INTEL_COMPILER)
|
||||
# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
|
||||
#pragma clang diagnostic ignored \
|
||||
"-Wunsequenced" // multiple unsequenced modifications to 'self' (when using
|
||||
// def(py::self OP Type()))
|
||||
#elif defined(_MSC_VER)
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#pragma warning(push)
|
||||
#pragma warning( \
|
||||
disable : 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
@@ -23,136 +26,191 @@ NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Enumeration with all supported operator types
|
||||
enum op_id : int {
|
||||
op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
|
||||
op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
|
||||
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
|
||||
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
|
||||
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
|
||||
op_repr, op_truediv, op_itruediv, op_hash
|
||||
op_add,
|
||||
op_sub,
|
||||
op_mul,
|
||||
op_div,
|
||||
op_mod,
|
||||
op_divmod,
|
||||
op_pow,
|
||||
op_lshift,
|
||||
op_rshift,
|
||||
op_and,
|
||||
op_xor,
|
||||
op_or,
|
||||
op_neg,
|
||||
op_pos,
|
||||
op_abs,
|
||||
op_invert,
|
||||
op_int,
|
||||
op_long,
|
||||
op_float,
|
||||
op_str,
|
||||
op_cmp,
|
||||
op_gt,
|
||||
op_ge,
|
||||
op_lt,
|
||||
op_le,
|
||||
op_eq,
|
||||
op_ne,
|
||||
op_iadd,
|
||||
op_isub,
|
||||
op_imul,
|
||||
op_idiv,
|
||||
op_imod,
|
||||
op_ilshift,
|
||||
op_irshift,
|
||||
op_iand,
|
||||
op_ixor,
|
||||
op_ior,
|
||||
op_complex,
|
||||
op_bool,
|
||||
op_nonzero,
|
||||
op_repr,
|
||||
op_truediv,
|
||||
op_itruediv,
|
||||
op_hash
|
||||
};
|
||||
|
||||
enum op_type : int {
|
||||
op_l, /* base type on left */
|
||||
op_r, /* base type on right */
|
||||
op_u /* unary operator */
|
||||
op_l, /* base type on left */
|
||||
op_r, /* base type on right */
|
||||
op_u /* unary operator */
|
||||
};
|
||||
|
||||
struct self_t { };
|
||||
struct self_t {};
|
||||
static const self_t self = self_t();
|
||||
|
||||
/// Type for an unused type slot
|
||||
struct undefined_t { };
|
||||
struct undefined_t {};
|
||||
|
||||
/// Don't warn about an unused variable
|
||||
inline self_t __self() { return self; }
|
||||
|
||||
/// base template of operator implementations
|
||||
template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
|
||||
template <op_id, op_type, typename B, typename L, typename R> struct op_impl {};
|
||||
|
||||
/// Operator implementation generator
|
||||
template <op_id id, op_type ot, typename L, typename R> struct op_ {
|
||||
template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__"
|
||||
: ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
template <typename Class, typename... Extra>
|
||||
void execute_cast(Class &cl, const Extra &... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__"
|
||||
: ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
|
||||
static B execute_cast(const L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
|
||||
static char const* name() { return "__" #rid "__"; } \
|
||||
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
|
||||
static B execute_cast(const R &r, const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
|
||||
return op_<op_##id, op_l, self_t, self_t>(); \
|
||||
} \
|
||||
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
} \
|
||||
template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
|
||||
return op_<op_##id, op_r, T, self_t>(); \
|
||||
}
|
||||
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
|
||||
template <typename B, typename L, typename R> \
|
||||
struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const *name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l, const R &r) -> decltype(expr) { \
|
||||
return (expr); \
|
||||
} \
|
||||
static B execute_cast(const L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename B, typename L, typename R> \
|
||||
struct op_impl<op_##id, op_r, B, L, R> { \
|
||||
static char const *name() { return "__" #rid "__"; } \
|
||||
static auto execute(const R &r, const L &l) -> decltype(expr) { \
|
||||
return (expr); \
|
||||
} \
|
||||
static B execute_cast(const R &r, const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, \
|
||||
const self_t &) { \
|
||||
return op_<op_##id, op_l, self_t, self_t>(); \
|
||||
} \
|
||||
template <typename T> \
|
||||
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
} \
|
||||
template <typename T> \
|
||||
op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
|
||||
return op_<op_##id, op_r, T, self_t>(); \
|
||||
}
|
||||
|
||||
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
}
|
||||
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L, typename R> \
|
||||
struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const *name() { return "__" #id "__"; } \
|
||||
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename T> \
|
||||
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
}
|
||||
|
||||
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
|
||||
return op_<op_##id, op_u, self_t, undefined_t>(); \
|
||||
}
|
||||
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L> \
|
||||
struct op_impl<op_##id, op_u, B, L, undefined_t> { \
|
||||
static char const *name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
|
||||
return op_<op_##id, op_u, self_t, undefined_t>(); \
|
||||
}
|
||||
|
||||
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
|
||||
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
|
||||
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
|
||||
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
|
||||
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
|
||||
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
|
||||
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
|
||||
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
|
||||
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
|
||||
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
|
||||
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
|
||||
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
|
||||
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
|
||||
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
|
||||
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
|
||||
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
|
||||
//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
|
||||
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
|
||||
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
|
||||
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
|
||||
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
|
||||
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
|
||||
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
|
||||
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
|
||||
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
|
||||
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
|
||||
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
|
||||
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
|
||||
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
|
||||
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
|
||||
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
|
||||
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
|
||||
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
|
||||
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
|
||||
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
|
||||
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
|
||||
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l>> r)
|
||||
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
|
||||
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
|
||||
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
|
||||
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
|
||||
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
|
||||
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l> r)
|
||||
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
|
||||
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l<r)
|
||||
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
|
||||
// PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l,
|
||||
// r))
|
||||
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
|
||||
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
|
||||
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
|
||||
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
|
||||
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
|
||||
PYBIND11_UNARY_OPERATOR(neg, operator-, - l)
|
||||
PYBIND11_UNARY_OPERATOR(pos, operator+, + l)
|
||||
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
|
||||
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
|
||||
PYBIND11_UNARY_OPERATOR(invert, operator~,(~l))
|
||||
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
|
||||
PYBIND11_UNARY_OPERATOR(int, int_, (int)l)
|
||||
PYBIND11_UNARY_OPERATOR(float, float_, (double)l)
|
||||
|
||||
#undef PYBIND11_BINARY_OPERATOR
|
||||
#undef PYBIND11_INPLACE_OPERATOR
|
||||
@@ -164,5 +222,5 @@ using detail::self;
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# pragma warning(pop)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
@@ -15,51 +15,65 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
class options {
|
||||
public:
|
||||
// Default RAII constructor, which leaves settings as they currently are.
|
||||
options() : previous_state(global_state()) {}
|
||||
|
||||
// Default RAII constructor, which leaves settings as they currently are.
|
||||
options() : previous_state(global_state()) {}
|
||||
// Class is non-copyable.
|
||||
options(const options &) = delete;
|
||||
options &operator=(const options &) = delete;
|
||||
|
||||
// Class is non-copyable.
|
||||
options(const options&) = delete;
|
||||
options& operator=(const options&) = delete;
|
||||
// Destructor, which restores settings that were in effect before.
|
||||
~options() { global_state() = previous_state; }
|
||||
|
||||
// Destructor, which restores settings that were in effect before.
|
||||
~options() {
|
||||
global_state() = previous_state;
|
||||
}
|
||||
// Setter methods (affect the global state):
|
||||
|
||||
// Setter methods (affect the global state):
|
||||
options &disable_user_defined_docstrings() & {
|
||||
global_state().show_user_defined_docstrings = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
|
||||
options &enable_user_defined_docstrings() & {
|
||||
global_state().show_user_defined_docstrings = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
|
||||
options &disable_function_signatures() & {
|
||||
global_state().show_function_signatures = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
|
||||
options &enable_function_signatures() & {
|
||||
global_state().show_function_signatures = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
|
||||
// Getter methods (return the global state):
|
||||
|
||||
// Getter methods (return the global state):
|
||||
static bool show_user_defined_docstrings() {
|
||||
return global_state().show_user_defined_docstrings;
|
||||
}
|
||||
|
||||
static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
|
||||
static bool show_function_signatures() {
|
||||
return global_state().show_function_signatures;
|
||||
}
|
||||
|
||||
static bool show_function_signatures() { return global_state().show_function_signatures; }
|
||||
|
||||
// This type is not meant to be allocated on the heap.
|
||||
void* operator new(size_t) = delete;
|
||||
// This type is not meant to be allocated on the heap.
|
||||
void *operator new(size_t) = delete;
|
||||
|
||||
private:
|
||||
struct state {
|
||||
bool show_user_defined_docstrings =
|
||||
true; //< Include user-supplied texts in docstrings.
|
||||
bool show_function_signatures =
|
||||
true; //< Include auto-generated function signatures in docstrings.
|
||||
};
|
||||
|
||||
struct state {
|
||||
bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
|
||||
bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
|
||||
};
|
||||
static state &global_state() {
|
||||
static state instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
static state &global_state() {
|
||||
static state instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
state previous_state;
|
||||
state previous_state;
|
||||
};
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -10,373 +10,411 @@
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <valarray>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#pragma warning( \
|
||||
disable : 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
#ifdef __has_include
|
||||
// std::optional (but including it in c++14 mode isn't allowed)
|
||||
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
|
||||
# include <optional>
|
||||
# define PYBIND11_HAS_OPTIONAL 1
|
||||
# endif
|
||||
#if defined(PYBIND11_CPP17) && __has_include(<optional>)
|
||||
#include <optional>
|
||||
#define PYBIND11_HAS_OPTIONAL 1
|
||||
#endif
|
||||
// std::experimental::optional (but not allowed in c++11 mode)
|
||||
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
|
||||
#if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
|
||||
!__has_include(<optional>))
|
||||
# include <experimental/optional>
|
||||
# define PYBIND11_HAS_EXP_OPTIONAL 1
|
||||
# endif
|
||||
#include <experimental/optional>
|
||||
#define PYBIND11_HAS_EXP_OPTIONAL 1
|
||||
#endif
|
||||
// std::variant
|
||||
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
|
||||
# include <variant>
|
||||
# define PYBIND11_HAS_VARIANT 1
|
||||
# endif
|
||||
#if defined(PYBIND11_CPP17) && __has_include(<variant>)
|
||||
#include <variant>
|
||||
#define PYBIND11_HAS_VARIANT 1
|
||||
#endif
|
||||
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
|
||||
# include <optional>
|
||||
# include <variant>
|
||||
# define PYBIND11_HAS_OPTIONAL 1
|
||||
# define PYBIND11_HAS_VARIANT 1
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
#define PYBIND11_HAS_OPTIONAL 1
|
||||
#define PYBIND11_HAS_VARIANT 1
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
|
||||
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
|
||||
/// Extracts an const lvalue reference or rvalue reference for U based on the
|
||||
/// type of T (e.g. for forwarding a container element). Typically used
|
||||
/// indirect via forwarded_type(), below.
|
||||
template <typename T, typename U>
|
||||
using forwarded_type = conditional_t<
|
||||
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
|
||||
using forwarded_type =
|
||||
conditional_t<std::is_lvalue_reference<T>::value, remove_reference_t<U> &,
|
||||
remove_reference_t<U> &&>;
|
||||
|
||||
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
|
||||
/// used for forwarding a container's elements.
|
||||
template <typename T, typename U>
|
||||
forwarded_type<T, U> forward_like(U &&u) {
|
||||
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
|
||||
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or
|
||||
/// lvalue; typically used for forwarding a container's elements.
|
||||
template <typename T, typename U> forwarded_type<T, U> forward_like(U &&u) {
|
||||
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
|
||||
}
|
||||
|
||||
template <typename Type, typename Key> struct set_caster {
|
||||
using type = Type;
|
||||
using key_conv = make_caster<Key>;
|
||||
using type = Type;
|
||||
using key_conv = make_caster<Key>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<pybind11::set>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<pybind11::set>(src);
|
||||
value.clear();
|
||||
for (auto entry : s) {
|
||||
key_conv conv;
|
||||
if (!conv.load(entry, convert))
|
||||
return false;
|
||||
value.insert(cast_op<Key &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<pybind11::set>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<pybind11::set>(src);
|
||||
value.clear();
|
||||
for (auto entry : s) {
|
||||
key_conv conv;
|
||||
if (!conv.load(entry, convert))
|
||||
return false;
|
||||
value.insert(cast_op<Key &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Key>::policy(policy);
|
||||
pybind11::set s;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_ || !s.add(value_))
|
||||
return handle();
|
||||
}
|
||||
return s.release();
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Key>::policy(policy);
|
||||
pybind11::set s;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(
|
||||
key_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_ || !s.add(value_))
|
||||
return handle();
|
||||
}
|
||||
return s.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
|
||||
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Key, typename Value> struct map_caster {
|
||||
using key_conv = make_caster<Key>;
|
||||
using value_conv = make_caster<Value>;
|
||||
using key_conv = make_caster<Key>;
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<dict>(src))
|
||||
return false;
|
||||
auto d = reinterpret_borrow<dict>(src);
|
||||
value.clear();
|
||||
for (auto it : d) {
|
||||
key_conv kconv;
|
||||
value_conv vconv;
|
||||
if (!kconv.load(it.first.ptr(), convert) ||
|
||||
!vconv.load(it.second.ptr(), convert))
|
||||
return false;
|
||||
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
|
||||
}
|
||||
return true;
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<dict>(src))
|
||||
return false;
|
||||
auto d = reinterpret_borrow<dict>(src);
|
||||
value.clear();
|
||||
for (auto it : d) {
|
||||
key_conv kconv;
|
||||
value_conv vconv;
|
||||
if (!kconv.load(it.first.ptr(), convert) ||
|
||||
!vconv.load(it.second.ptr(), convert))
|
||||
return false;
|
||||
value.emplace(cast_op<Key &&>(std::move(kconv)),
|
||||
cast_op<Value &&>(std::move(vconv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
dict d;
|
||||
return_value_policy policy_key = policy;
|
||||
return_value_policy policy_value = policy;
|
||||
if (!std::is_lvalue_reference<T>::value) {
|
||||
policy_key = return_value_policy_override<Key>::policy(policy_key);
|
||||
policy_value = return_value_policy_override<Value>::policy(policy_value);
|
||||
}
|
||||
for (auto &&kv : src) {
|
||||
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
|
||||
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
|
||||
if (!key || !value)
|
||||
return handle();
|
||||
d[key] = value;
|
||||
}
|
||||
return d.release();
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
dict d;
|
||||
return_value_policy policy_key = policy;
|
||||
return_value_policy policy_value = policy;
|
||||
if (!std::is_lvalue_reference<T>::value) {
|
||||
policy_key = return_value_policy_override<Key>::policy(policy_key);
|
||||
policy_value = return_value_policy_override<Value>::policy(policy_value);
|
||||
}
|
||||
for (auto &&kv : src) {
|
||||
auto key = reinterpret_steal<object>(
|
||||
key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
|
||||
auto value = reinterpret_steal<object>(
|
||||
value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
|
||||
if (!key || !value)
|
||||
return handle();
|
||||
d[key] = value;
|
||||
}
|
||||
return d.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
|
||||
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") +
|
||||
value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Value> struct list_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src) || isinstance<str>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<sequence>(src);
|
||||
value.clear();
|
||||
reserve_maybe(s, &value);
|
||||
for (auto it : s) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value.push_back(cast_op<Value &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src) || isinstance<str>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<sequence>(src);
|
||||
value.clear();
|
||||
reserve_maybe(s, &value);
|
||||
for (auto it : s) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value.push_back(cast_op<Value &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T = Type,
|
||||
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
|
||||
void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
|
||||
void reserve_maybe(sequence, void *) { }
|
||||
template <typename T = Type,
|
||||
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)),
|
||||
void>::value,
|
||||
int> = 0>
|
||||
void reserve_maybe(sequence s, Type *) {
|
||||
value.reserve(s.size());
|
||||
}
|
||||
void reserve_maybe(sequence, void *) {}
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Value>::policy(policy);
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Value>::policy(policy);
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(
|
||||
value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t)index++,
|
||||
value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
|
||||
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
|
||||
: list_caster<std::vector<Type, Alloc>, Type> { };
|
||||
template <typename Type, typename Alloc>
|
||||
struct type_caster<std::vector<Type, Alloc>>
|
||||
: list_caster<std::vector<Type, Alloc>, Type> {};
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
|
||||
: list_caster<std::deque<Type, Alloc>, Type> { };
|
||||
template <typename Type, typename Alloc>
|
||||
struct type_caster<std::deque<Type, Alloc>>
|
||||
: list_caster<std::deque<Type, Alloc>, Type> {};
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
|
||||
: list_caster<std::list<Type, Alloc>, Type> { };
|
||||
template <typename Type, typename Alloc>
|
||||
struct type_caster<std::list<Type, Alloc>>
|
||||
: list_caster<std::list<Type, Alloc>, Type> {};
|
||||
|
||||
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0>
|
||||
struct array_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
private:
|
||||
template <bool R = Resizable>
|
||||
bool require_size(enable_if_t<R, size_t> size) {
|
||||
if (value.size() != size)
|
||||
value.resize(size);
|
||||
return true;
|
||||
}
|
||||
template <bool R = Resizable>
|
||||
bool require_size(enable_if_t<!R, size_t> size) {
|
||||
return size == Size;
|
||||
}
|
||||
template <bool R = Resizable> bool require_size(enable_if_t<R, size_t> size) {
|
||||
if (value.size() != size)
|
||||
value.resize(size);
|
||||
return true;
|
||||
}
|
||||
template <bool R = Resizable>
|
||||
bool require_size(enable_if_t<!R, size_t> size) {
|
||||
return size == Size;
|
||||
}
|
||||
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src))
|
||||
return false;
|
||||
auto l = reinterpret_borrow<sequence>(src);
|
||||
if (!require_size(l.size()))
|
||||
return false;
|
||||
size_t ctr = 0;
|
||||
for (auto it : l) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value[ctr++] = cast_op<Value &&>(std::move(conv));
|
||||
}
|
||||
return true;
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src))
|
||||
return false;
|
||||
auto l = reinterpret_borrow<sequence>(src);
|
||||
if (!require_size(l.size()))
|
||||
return false;
|
||||
size_t ctr = 0;
|
||||
for (auto it : l) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value[ctr++] = cast_op<Value &&>(std::move(conv));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(
|
||||
value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t)index++,
|
||||
value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
|
||||
PYBIND11_TYPE_CASTER(ArrayType,
|
||||
_("List[") + value_conv::name +
|
||||
_<Resizable>(_(""), _("[") + _<Size>() + _("]")) +
|
||||
_("]"));
|
||||
};
|
||||
|
||||
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
|
||||
: array_caster<std::array<Type, Size>, Type, false, Size> { };
|
||||
template <typename Type, size_t Size>
|
||||
struct type_caster<std::array<Type, Size>>
|
||||
: array_caster<std::array<Type, Size>, Type, false, Size> {};
|
||||
|
||||
template <typename Type> struct type_caster<std::valarray<Type>>
|
||||
: array_caster<std::valarray<Type>, Type, true> { };
|
||||
template <typename Type>
|
||||
struct type_caster<std::valarray<Type>>
|
||||
: array_caster<std::valarray<Type>, Type, true> {};
|
||||
|
||||
template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
|
||||
: set_caster<std::set<Key, Compare, Alloc>, Key> { };
|
||||
template <typename Key, typename Compare, typename Alloc>
|
||||
struct type_caster<std::set<Key, Compare, Alloc>>
|
||||
: set_caster<std::set<Key, Compare, Alloc>, Key> {};
|
||||
|
||||
template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
|
||||
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
|
||||
template <typename Key, typename Hash, typename Equal, typename Alloc>
|
||||
struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
|
||||
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> {};
|
||||
|
||||
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
|
||||
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
|
||||
template <typename Key, typename Value, typename Compare, typename Alloc>
|
||||
struct type_caster<std::map<Key, Value, Compare, Alloc>>
|
||||
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> {};
|
||||
|
||||
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
|
||||
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
|
||||
template <typename Key, typename Value, typename Hash, typename Equal,
|
||||
typename Alloc>
|
||||
struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
|
||||
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key,
|
||||
Value> {};
|
||||
|
||||
// This type caster is intended to be used for std::optional and std::experimental::optional
|
||||
template<typename T> struct optional_caster {
|
||||
using value_conv = make_caster<typename T::value_type>;
|
||||
// This type caster is intended to be used for std::optional and
|
||||
// std::experimental::optional
|
||||
template <typename T> struct optional_caster {
|
||||
using value_conv = make_caster<typename T::value_type>;
|
||||
|
||||
template <typename T_>
|
||||
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
|
||||
if (!src)
|
||||
return none().inc_ref();
|
||||
policy = return_value_policy_override<typename T::value_type>::policy(policy);
|
||||
return value_conv::cast(*std::forward<T_>(src), policy, parent);
|
||||
template <typename T_>
|
||||
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
|
||||
if (!src)
|
||||
return none().inc_ref();
|
||||
policy =
|
||||
return_value_policy_override<typename T::value_type>::policy(policy);
|
||||
return value_conv::cast(*std::forward<T_>(src), policy, parent);
|
||||
}
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src) {
|
||||
return false;
|
||||
} else if (src.is_none()) {
|
||||
return true; // default-constructed value is already empty
|
||||
}
|
||||
value_conv inner_caster;
|
||||
if (!inner_caster.load(src, convert))
|
||||
return false;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src) {
|
||||
return false;
|
||||
} else if (src.is_none()) {
|
||||
return true; // default-constructed value is already empty
|
||||
}
|
||||
value_conv inner_caster;
|
||||
if (!inner_caster.load(src, convert))
|
||||
return false;
|
||||
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
|
||||
return true;
|
||||
}
|
||||
|
||||
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
|
||||
return true;
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
|
||||
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_OPTIONAL
|
||||
template<typename T> struct type_caster<std::optional<T>>
|
||||
template <typename T>
|
||||
struct type_caster<std::optional<T>>
|
||||
: public optional_caster<std::optional<T>> {};
|
||||
|
||||
template<> struct type_caster<std::nullopt_t>
|
||||
: public void_caster<std::nullopt_t> {};
|
||||
template <>
|
||||
struct type_caster<std::nullopt_t> : public void_caster<std::nullopt_t> {};
|
||||
#endif
|
||||
|
||||
#if PYBIND11_HAS_EXP_OPTIONAL
|
||||
template<typename T> struct type_caster<std::experimental::optional<T>>
|
||||
template <typename T>
|
||||
struct type_caster<std::experimental::optional<T>>
|
||||
: public optional_caster<std::experimental::optional<T>> {};
|
||||
|
||||
template<> struct type_caster<std::experimental::nullopt_t>
|
||||
template <>
|
||||
struct type_caster<std::experimental::nullopt_t>
|
||||
: public void_caster<std::experimental::nullopt_t> {};
|
||||
#endif
|
||||
|
||||
/// Visit a variant and cast any found type to Python
|
||||
struct variant_caster_visitor {
|
||||
return_value_policy policy;
|
||||
handle parent;
|
||||
return_value_policy policy;
|
||||
handle parent;
|
||||
|
||||
using result_type = handle; // required by boost::variant in C++11
|
||||
using result_type = handle; // required by boost::variant in C++11
|
||||
|
||||
template <typename T>
|
||||
result_type operator()(T &&src) const {
|
||||
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
|
||||
}
|
||||
template <typename T> result_type operator()(T &&src) const {
|
||||
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
|
||||
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
|
||||
/// automatically using argument-dependent lookup. Users can provide specializations for other
|
||||
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
|
||||
template <template<typename...> class Variant>
|
||||
struct visit_helper {
|
||||
template <typename... Args>
|
||||
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
|
||||
return visit(std::forward<Args>(args)...);
|
||||
}
|
||||
/// Helper class which abstracts away variant's `visit` function. `std::variant`
|
||||
/// and similar `namespace::variant` types which provide a `namespace::visit()`
|
||||
/// function are handled here automatically using argument-dependent lookup.
|
||||
/// Users can provide specializations for other variant-like classes, e.g.
|
||||
/// `boost::variant` and `boost::apply_visitor`.
|
||||
template <template <typename...> class Variant> struct visit_helper {
|
||||
template <typename... Args>
|
||||
static auto call(Args &&... args)
|
||||
-> decltype(visit(std::forward<Args>(args)...)) {
|
||||
return visit(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
/// Generic variant caster
|
||||
template <typename Variant> struct variant_caster;
|
||||
|
||||
template <template<typename...> class V, typename... Ts>
|
||||
template <template <typename...> class V, typename... Ts>
|
||||
struct variant_caster<V<Ts...>> {
|
||||
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
|
||||
static_assert(sizeof...(Ts) > 0,
|
||||
"Variant must consist of at least one alternative.");
|
||||
|
||||
template <typename U, typename... Us>
|
||||
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
|
||||
auto caster = make_caster<U>();
|
||||
if (caster.load(src, convert)) {
|
||||
value = cast_op<U>(caster);
|
||||
return true;
|
||||
}
|
||||
return load_alternative(src, convert, type_list<Us...>{});
|
||||
template <typename U, typename... Us>
|
||||
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
|
||||
auto caster = make_caster<U>();
|
||||
if (caster.load(src, convert)) {
|
||||
value = cast_op<U>(caster);
|
||||
return true;
|
||||
}
|
||||
return load_alternative(src, convert, type_list<Us...>{});
|
||||
}
|
||||
|
||||
bool load_alternative(handle, bool, type_list<>) { return false; }
|
||||
bool load_alternative(handle, bool, type_list<>) { return false; }
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
// Do a first pass without conversions to improve constructor resolution.
|
||||
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
|
||||
// slot of the variant. Without two-pass loading `double` would be filled
|
||||
// because it appears first and a conversion is possible.
|
||||
if (convert && load_alternative(src, false, type_list<Ts...>{}))
|
||||
return true;
|
||||
return load_alternative(src, convert, type_list<Ts...>{});
|
||||
}
|
||||
bool load(handle src, bool convert) {
|
||||
// Do a first pass without conversions to improve constructor resolution.
|
||||
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
|
||||
// slot of the variant. Without two-pass loading `double` would be filled
|
||||
// because it appears first and a conversion is possible.
|
||||
if (convert && load_alternative(src, false, type_list<Ts...>{}))
|
||||
return true;
|
||||
return load_alternative(src, convert, type_list<Ts...>{});
|
||||
}
|
||||
|
||||
template <typename Variant>
|
||||
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
|
||||
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
|
||||
std::forward<Variant>(src));
|
||||
}
|
||||
template <typename Variant>
|
||||
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
|
||||
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
|
||||
std::forward<Variant>(src));
|
||||
}
|
||||
|
||||
using Type = V<Ts...>;
|
||||
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
|
||||
using Type = V<Ts...>;
|
||||
PYBIND11_TYPE_CASTER(Type, _("Union[") +
|
||||
detail::concat(make_caster<Ts>::name...) +
|
||||
_("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_VARIANT
|
||||
template <typename... Ts>
|
||||
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
|
||||
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> {
|
||||
};
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
|
||||
os << (std::string) str(obj);
|
||||
return os;
|
||||
os << (std::string)str(obj);
|
||||
return os;
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
File diff suppressed because it is too large
Load Diff
2018
python/src/triton.cc
2018
python/src/triton.cc
File diff suppressed because it is too large
Load Diff
@@ -1,48 +1,51 @@
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace{
|
||||
namespace {
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
|
||||
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
|
||||
|
||||
void print(const std::string &name, raw_ostream &os, ArrayRef<int> vals) {
|
||||
os << name << ": [";
|
||||
for(size_t d = 0; d < vals.size(); d++){
|
||||
if(d != 0) os << ", ";
|
||||
for (size_t d = 0; d < vals.size(); d++) {
|
||||
if (d != 0)
|
||||
os << ", ";
|
||||
os << vals[d];
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||
StringRef getDescription() const final
|
||||
{ return "print the result of the alignment analysis pass"; }
|
||||
StringRef getDescription() const final {
|
||||
return "print the result of the alignment analysis pass";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation* operation = getOperation();
|
||||
auto& os = llvm::errs();
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
os << "Testing: " << operation->getName() << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
operation->walk([&](Operation* op){
|
||||
if(op->getNumResults() < 1)
|
||||
operation->walk([&](Operation *op) {
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
for(Value result: op->getResults()){
|
||||
for (Value result : op->getResults()) {
|
||||
// std::ostringstream oss;
|
||||
// result.print(oss);
|
||||
// os << " => ";
|
||||
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
|
||||
if(!latticeElement){
|
||||
LatticeElement<AxisInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(result);
|
||||
if (!latticeElement) {
|
||||
os << "None\n";
|
||||
return;
|
||||
}
|
||||
AxisInfo& info = latticeElement->getValue();
|
||||
AxisInfo &info = latticeElement->getValue();
|
||||
print("Contiguity", os, info.getContiguity());
|
||||
os << " ; ";
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
@@ -50,18 +53,17 @@ struct TestAxisInfoPass
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << " ) ";
|
||||
os << "\n";
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir{
|
||||
namespace test{
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
Reference in New Issue
Block a user