[CI] run clang-format (#24)

This commit is contained in:
Philippe Tillet
2022-07-26 17:25:03 -07:00
committed by GitHub
parent 25357083e6
commit 6d62d88d4f
62 changed files with 13673 additions and 11367 deletions

View File

@@ -27,10 +27,16 @@ jobs:
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check style
- name: Check python style
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
run: |
sudo apt-get install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
run: |

View File

@@ -57,9 +57,8 @@ static cl::opt<bool> NoCanonicalizeWhiteSpace(
"strict-whitespace",
cl::desc("Do not treat all horizontal whitespace as equivalent"));
static cl::opt<bool> IgnoreCase(
"ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::opt<bool> IgnoreCase("ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::list<std::string> ImplicitCheckNot(
"implicit-check-not",
@@ -169,12 +168,6 @@ static cl::list<unsigned> DumpInputContexts(
typedef cl::list<std::string>::const_iterator prefix_iterator;
static void DumpCommandLine(int argc, char **argv) {
errs() << "FileCheck command line: ";
for (int I = 0; I < argc; I++)
@@ -613,8 +606,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
ElidedLinesOS.enable_colors(true);
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
for (unsigned Line = 1;
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd;
++Line) {
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
const unsigned char *InputFileLine = InputFilePtr;
// Compute the previous and next line included by the filter.
@@ -691,8 +683,7 @@ static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
unsigned InputLineWidth = InputFilePtr - InputFileLine;
// Print any annotations.
while (AnnotationItr != AnnotationEnd &&
AnnotationItr->InputLine == Line) {
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
/*BG=*/false, TheColorMode);
// The two spaces below are where the ": " appears on input lines.

View File

@@ -10,11 +10,11 @@
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir{
namespace test{
namespace mlir {
namespace test {
void registerTestAlignmentPass();
}
}
} // namespace mlir
int main(int argc, char **argv) {
mlir::registerAllPasses();
@@ -25,13 +25,11 @@ int main(int argc, char **argv) {
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::arith::ArithmeticDialect,
mlir::StandardOpsDialect,
mlir::scf::SCFDialect>();
registry
.insert<mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect>();
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry)
);
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
}

View File

@@ -10,7 +10,6 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
@@ -25,26 +24,25 @@ public:
public:
// Default constructor
AxisInfo(): AxisInfo({}, {}, {}) { }
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
ConstancyT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == rank);
assert(knownConstancy.size() == rank);
}
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == rank);
assert(knownConstancy.size() == rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const ContiguityT& getContiguity() const { return contiguity; }
int getContiguity(size_t d) const { return contiguity[d]; }
const ContiguityT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DivisibilityT& getDivisibility() const { return divisibility; }
const DivisibilityT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const ConstancyT& getConstancy() const { return constancy; }
int getConstancy(size_t d) const { return constancy[d]; }
const ConstancyT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
@@ -56,13 +54,13 @@ public:
}
/// The pessimistic value state of the contiguity is unknown.
static AxisInfo getPessimisticValueState(MLIRContext *context)
{ return AxisInfo(); }
static AxisInfo getPessimisticValueState(MLIRContext *context) {
return AxisInfo();
}
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs,
const AxisInfo &rhs);
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
@@ -81,7 +79,7 @@ private:
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
ContiguityT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
@@ -107,39 +105,36 @@ private:
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
ConstancyT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis
: public ForwardDataFlowAnalysis<AxisInfo> {
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n){
if(n==0)
int highestPowOf2Divisor(int n) {
if (n == 0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
AxisInfo visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
};
}
} // namespace mlir
#endif

View File

@@ -3,17 +3,13 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
namespace mlir
{
namespace triton
{
namespace mlir {
namespace triton {
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -3,18 +3,17 @@
#include <memory>
namespace mlir{
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton{
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps = 4);
}
} // namespace mlir
#endif

View File

@@ -1,17 +1,16 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"

View File

@@ -19,7 +19,7 @@ public:
static LogicalResult verifyTrait(Operation *op) {
// The rationale for this number is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
// hang.
// Since H100 has 256KB registers, we should allow users to create tensors
// of size up to 256K elements. It will spill for datatypes wider than 1B,
// but we probably should limit number of elements (rather than bytes) to
@@ -31,8 +31,8 @@ public:
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxElement)
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
<< *op << " has more than that";
return op->emitError("Maximum allowed number of elements is ")
<< maxElement << ", but " << *op << " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
@@ -45,8 +45,8 @@ public:
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxElement)
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
<< *op << " has more than that";
return op->emitError("Maximum allowed number of elements is ")
<< maxElement << ", but " << *op << " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
@@ -57,7 +57,7 @@ public:
}
};
}
}
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -13,6 +13,6 @@ std::unique_ptr<Pass> createCombineOpsPass();
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
}
} // namespace mlir
#endif

View File

@@ -15,5 +15,4 @@
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -14,6 +14,7 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
private:
MLIRContext *context;
int numThreads;
@@ -21,8 +22,10 @@ private:
class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter;
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
explicit TritonGPUConversionTarget(MLIRContext &ctx,
TritonGPUTypeConverter &typeConverter);
/// update layouts & insert ConvertLayoutOp if necessary
LogicalResult refineLayouts(ModuleOp mod, int numThreads);

387
include/triton/driver/dispatch.h Executable file → Normal file
View File

@@ -3,10 +3,10 @@
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <type_traits>
#include <dlfcn.h>
#include <type_traits>
//CUDA Backend
// CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
@@ -14,47 +14,43 @@
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
// Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
}
} // namespace llvm
namespace triton
{
namespace driver
{
namespace triton {
namespace driver {
class cu_context;
template<class T> void check(T){}
template <class T> void check(T) {}
void check(CUresult err);
void check(hipError_t err);
class dispatch
{
class dispatch {
protected:
template <class F>
struct return_type;
template <class F> struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
template <class R, class... A> struct return_type<R (*)(A...)> {
typedef R type;
};
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
template <f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type
f_impl(void *&lib_h, FunPtrT, void *&cache, const char *name, Args... args) {
initializer();
if(cache == nullptr){
if (cache == nullptr) {
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
if (cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
@@ -76,63 +72,99 @@ public:
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags,
CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuCtxGetDevice(CUdevice *result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext,
unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib,
CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type,
void *data, size_t size, const char *name,
unsigned int numOptions,
CUjit_option *options, void **optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions,
CUjit_option *options, void **optionValues,
CUlinkState *stateOut);
static CUresult cuLinkComplete(CUlinkState state, void **cubinOut,
size_t *sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t *bytes,
CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleLoadData(CUmodule *module, const void *image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image,
unsigned int numOptions,
CUjit_option *options,
void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod,
const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext *pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY,
unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream,
void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncGetAttribute(int *pi, CUfunction_attribute attrib,
CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc,
CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuPointerGetAttribute(void *data,
CUpointer_attribute attribute,
CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N,
CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice,
size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice,
size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice,
const void *srcHost, size_t ByteCount,
CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost,
size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart,
CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2(const char *pciBusId,
nvmlDevice_t *device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device,
nvmlClockType_t type,
unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device,
nvmlClockType_t type,
unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device,
unsigned int mem_clock,
unsigned int sm_clock);
/* ------------------- *
* HIP
@@ -140,177 +172,198 @@ public:
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags,
hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t* result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
static hipError_t hipCtxGetDevice(hipDevice_t *result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext,
unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib,
hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
static hipError_t hipModuleLoadData(hipModule_t *module, const void *image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image,
unsigned int numOptions,
hipJitOption *options,
void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod,
const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
static hipError_t
hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY,
unsigned int blockDimZ, unsigned int sharedMemBytes,
hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
static hipError_t hipFuncGetAttributes(hipFuncAttributes *attrib,
void *hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc,
hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc,
hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
static hipError_t hipPointerGetAttribute(void *data,
CUpointer_attribute attribute,
hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x,
size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice,
size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice,
size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice,
const void *srcHost, size_t ByteCount,
hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost,
size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart,
hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* hip_;
static void *cuda_;
static void *nvml_;
static void *hip_;
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
static void *cuCtxGetCurrent_;
static void *cuCtxSetCurrent_;
static void *cuCtxDestroy_v2_;
static void *cuCtxCreate_v2_;
static void *cuCtxGetDevice_;
static void *cuCtxPushCurrent_v2_;
static void *cuCtxPopCurrent_v2_;
static void *cuCtxEnablePeerAccess_;
static void *cuDriverGetVersion_;
static void *cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
static void *cuDeviceGet_;
static void *cuDeviceGetName_;
static void *cuDeviceGetPCIBusId_;
static void *cuDeviceGetAttribute_;
static void *cuDeviceGetCount_;
// link management
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuLinkComplete_;
static void *cuLinkAddData_v2_;
static void *cuLinkCreate_v2_;
static void *cuLinkDestroy_;
static void *cuLinkComplete_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
static void *cuModuleGetGlobal_v2_;
static void *cuModuleLoad_;
static void *cuModuleUnload_;
static void *cuModuleLoadDataEx_;
static void *cuModuleLoadData_;
static void *cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuLaunchKernel_;
static void *cuStreamCreate_;
static void *cuStreamSynchronize_;
static void *cuStreamDestroy_v2_;
static void *cuStreamGetCtx_;
static void *cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
static void *cuFuncGetAttribute_;
static void *cuFuncSetAttribute_;
static void *cuFuncSetCacheConfig_;
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
static void *cuMemcpyDtoH_v2_;
static void *cuMemFree_v2_;
static void *cuMemcpyDtoHAsync_v2_;
static void *cuMemcpyHtoDAsync_v2_;
static void *cuMemcpyHtoD_v2_;
static void *cuMemAlloc_v2_;
static void *cuMemsetD8Async_;
static void *cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
static void *cuEventCreate_;
static void *cuEventElapsedTime_;
static void *cuEventRecord_;
static void *cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
static void *nvmlInit_v2_;
static void *nvmlDeviceGetHandleByPciBusId_v2_;
static void *nvmlDeviceGetClockInfo_;
static void *nvmlDeviceGetMaxClockInfo_;
static void *nvmlDeviceSetApplicationsClocks_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void* hipInit_;
static void* hipCtxDestroy_;
static void* hipCtxCreate_;
static void* hipCtxPushCurrent_;
static void* hipCtxPopCurrent_;
static void* hipCtxGetDevice_;
static void* hipCtxEnablePeerAccess_;
static void* hipDriverGetVersion_;
static void *hipInit_;
static void *hipCtxDestroy_;
static void *hipCtxCreate_;
static void *hipCtxPushCurrent_;
static void *hipCtxPopCurrent_;
static void *hipCtxGetDevice_;
static void *hipCtxEnablePeerAccess_;
static void *hipDriverGetVersion_;
// device management
static void* hipGetDevice_;
static void* hipDeviceGetName_;
static void* hipDeviceGetPCIBusId_;
static void* hipDeviceGetAttribute_;
static void* hipGetDeviceCount_;
static void *hipGetDevice_;
static void *hipDeviceGetName_;
static void *hipDeviceGetPCIBusId_;
static void *hipDeviceGetAttribute_;
static void *hipGetDeviceCount_;
// module management
static void* hipModuleGetGlobal_;
static void* hipModuleLoad_;
static void* hipModuleLoadData_;
static void* hipModuleUnload_;
static void* hipModuleLoadDataEx_;
static void* hipModuleGetFunction_;
static void *hipModuleGetGlobal_;
static void *hipModuleLoad_;
static void *hipModuleLoadData_;
static void *hipModuleUnload_;
static void *hipModuleLoadDataEx_;
static void *hipModuleGetFunction_;
// stream management
static void* hipStreamCreate_;
static void* hipStreamSynchronize_;
static void* hipStreamDestroy_;
static void* hipModuleLaunchKernel_;;
static void *hipStreamCreate_;
static void *hipStreamSynchronize_;
static void *hipStreamDestroy_;
static void *hipModuleLaunchKernel_;
;
// function management
static void* hipFuncGetAttributes_;
static void* hipFuncSetAttribute_;
static void* hipFuncSetCacheConfig_;
static void *hipFuncGetAttributes_;
static void *hipFuncSetAttribute_;
static void *hipFuncSetCacheConfig_;
// memory management
static void* hipMalloc_;
static void* hipPointerGetAttribute_;
static void* hipMemsetD8Async_;
static void* hipMemcpyDtoH_;
static void* hipFree_;
static void* hipMemcpyDtoHAsync_;
static void* hipMemcpyHtoDAsync_;
static void* hipMemcpyHtoD_;
static void *hipMalloc_;
static void *hipPointerGetAttribute_;
static void *hipMemsetD8Async_;
static void *hipMemcpyDtoH_;
static void *hipFree_;
static void *hipMemcpyDtoHAsync_;
static void *hipMemcpyHtoDAsync_;
static void *hipMemcpyHtoD_;
// event management
static void* hipEventCreate_;
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
static void *hipEventCreate_;
static void *hipEventElapsedTime_;
static void *hipEventRecord_;
static void *hipEventDestroy_;
};
}
}
} // namespace driver
} // namespace triton
#endif

415
include/triton/driver/error.h Executable file → Normal file
View File

@@ -3,223 +3,252 @@
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include <exception>
#include "triton/driver/dispatch.h"
#include <exception>
namespace triton {
namespace triton
{
namespace driver {
namespace driver
{
namespace exception {
namespace exception
{
namespace nvrtc {
namespace nvrtc
{
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \
class name : public std::exception { \
public: \
const char *what() const throw() override { return "NVRTC: Error- " msg; } \
}
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \
class name: public std::exception { public: const char * what() const throw() override { return "NVRTC: Error- " msg; } }
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure,
"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input, "invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program, "invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option, "invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation, "compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure,
"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error, "unknown error");
#undef TRITON_CREATE_NVRTC_EXCEPTION
} // namespace nvrtc
namespace cuda {
class base : public std::exception {};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { return "CUDA: Error- " msg; } \
}
namespace cuda
{
class base: public std::exception{};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \
class name: public base { public:const char * what() const throw() override { return "CUDA: Error- " msg; } }
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUDA_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized, "deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled, "profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized,
"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started,
"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped,
"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device, "no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device, "invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image, "invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context, "invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current,
"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed, "map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed, "unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped, "array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped, "already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired, "already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped, "not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array, "not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit, "unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use, "context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported,
"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx, "invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context,
"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source, "invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found, "file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found,
"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed,
"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system, "operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle, "invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found, "not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready, "not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address, "illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources,
"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout, "launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing,
"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled,
"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled,
"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active, "primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed, "context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error, "assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers, "too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered,
"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered,
"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error, "hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction, "illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address, "misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space, "invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc, "invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed, "launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted, "not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
} // namespace cuda
namespace cublas {
class base : public std::exception {};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { \
return "CUBLAS: Error- " msg; \
} \
}
namespace cublas
{
class base: public std::exception{};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \
class name: public base { public: const char * what() const throw() override { return "CUBLAS: Error- " msg; } }
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed, "alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch, "arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error, "mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed, "execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error, "internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error, "license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_CUBLAS_EXCEPTION
} // namespace cublas
namespace cudnn {
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \
class name : public std::exception { \
public: \
const char *what() const throw() override { return "CUDNN: Error- " msg; } \
}
namespace cudnn
{
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \
class name: public std::exception { public: const char * what() const throw() override { return "CUDNN: Error- " msg; } }
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed, "allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param, "bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error, "internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch, "arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error, "mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed, "execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error, "license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing,
"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress, "runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow, "runtime fp overflow");
} // namespace cudnn
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
namespace hip {
class base : public std::exception {};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { return "HIP: Error- " msg; } \
}
namespace hip
{
class base: public std::exception{};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \
class name: public base { public:const char * what() const throw() override { return "HIP: Error- " msg; } }
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_HIP_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized, "deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled, "profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized,
"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started,
"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped,
"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device, "no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device, "invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image, "invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context, "invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current, "context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed, "map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed, "unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped, "array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped, "already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired, "already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped, "not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array, "not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit, "unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use, "context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported, "peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx, "invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context,
"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source, "invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found, "file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found,
"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed,
"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system, "operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle, "invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found, "not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready, "not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address, "illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources, "launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout, "launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing,
"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled,
"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled, "peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active, "primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed, "context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error, "assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers, "too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered,
"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered,
"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error, "hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction, "illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address, "misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space, "invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc, "invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed, "launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted, "not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol, "invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
} // namespace hip
}
}
}
} // namespace exception
} // namespace driver
} // namespace triton
#endif

View File

@@ -1,20 +1,21 @@
#include <string>
#include "triton/driver/dispatch.h"
#include <string>
namespace llvm{
namespace llvm {
class Module;
}
namespace triton{
namespace driver{
namespace triton {
namespace driver {
void init_llvm();
std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
std::string path_to_ptxas(int &version);
std::string llir_to_ptx(llvm::Module *module, int cc, int version);
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas_path,
int cc);
CUmodule ptx_to_cumodule(const std::string &ptx, int cc);
std::string llir_to_amdgpu(llvm::Module *module, const std::string &proc);
hipModule_t amdgpu_to_hipmodule(const std::string &path);
}
}
} // namespace driver
} // namespace triton

View File

@@ -3,52 +3,55 @@
#ifndef _TRITON_TOOLS_BENCH_H_
#define _TRITON_TOOLS_BENCH_H_
#include <chrono>
#include <functional>
#include <algorithm>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include <algorithm>
#include <chrono>
#include <functional>
namespace triton{
namespace tools{
namespace triton {
namespace tools {
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
class timer {
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false)
{ if (run) start(); }
explicit timer(bool run = false) {
if (run)
start();
}
void start()
{ _start = high_resolution_clock::now(); }
void start() { _start = high_resolution_clock::now(); }
nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
nanoseconds get() const {
return std::chrono::duration_cast<nanoseconds>(
high_resolution_clock::now() - _start);
}
private:
high_resolution_clock::time_point _start;
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
{
inline double bench(std::function<void()> const &op, driver::stream *stream,
size_t warmup = 10, size_t repeat = 200) {
timer tmr;
std::vector<size_t> times;
double total_time = 0;
for(size_t i = 0; i < warmup; i++)
for (size_t i = 0; i < warmup; i++)
op();
stream->synchronize();
tmr.start();
for(size_t i = 0; i < repeat; i++){
for (size_t i = 0; i < repeat; i++) {
op();
}
stream->synchronize();
return (float)tmr.get().count() / repeat;
// return *std::min_element(times.begin(), times.end());
// return *std::min_element(times.begin(), times.end());
}
}
}
} // namespace tools
} // namespace triton
#endif

View File

@@ -3,16 +3,15 @@
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <iostream>
namespace triton {
namespace tools{
namespace tools {
template<class node_t>
class graph {
template <class node_t> class graph {
typedef std::map<node_t, std::set<node_t>> edges_t;
public:
@@ -21,27 +20,27 @@ public:
private:
void connected_components_impl(node_t x, std::set<node_t> &nodes,
nmap_t* nmap, cmap_t* cmap, int id) const {
if(nmap)
nmap_t *nmap, cmap_t *cmap, int id) const {
if (nmap)
(*nmap)[x] = id;
if(cmap)
if (cmap)
(*cmap)[id].push_back(x);
if(nodes.find(x) != nodes.end()) {
if (nodes.find(x) != nodes.end()) {
nodes.erase(x);
for(const node_t &y: edges_.at(x))
for (const node_t &y : edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id);
}
}
public:
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
if(cmap)
if (cmap)
cmap->clear();
if(nmap)
if (nmap)
nmap->clear();
std::set<node_t> nodes = nodes_;
unsigned id = 0;
while(!nodes.empty()){
while (!nodes.empty()) {
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
}
@@ -63,7 +62,7 @@ private:
edges_t edges_;
};
}
}
} // namespace tools
} // namespace triton
#endif

View File

@@ -33,154 +33,140 @@
#ifndef _TRITON_TOOLS_SHA1_HPP_
#define _TRITON_TOOLS_SHA1_HPP_
namespace sha1
namespace sha1 {
namespace // local
{
namespace // local
{
// Rotate an integer value to left.
inline unsigned int rol(const unsigned int value,
const unsigned int steps)
{
return ((value << steps) | (value >> (32 - steps)));
}
// Rotate an integer value to left.
inline unsigned int rol(const unsigned int value, const unsigned int steps) {
return ((value << steps) | (value >> (32 - steps)));
}
// Sets the first 16 integers in the buffert to zero.
// Used for clearing the W buffert.
inline void clearWBuffert(unsigned int* buffert)
{
for (int pos = 16; --pos >= 0;)
{
buffert[pos] = 0;
}
}
// Sets the first 16 integers in the buffert to zero.
// Used for clearing the W buffert.
inline void clearWBuffert(unsigned int *buffert) {
for (int pos = 16; --pos >= 0;) {
buffert[pos] = 0;
}
}
inline void innerHash(unsigned int* result, unsigned int* w)
{
unsigned int a = result[0];
unsigned int b = result[1];
unsigned int c = result[2];
unsigned int d = result[3];
unsigned int e = result[4];
inline void innerHash(unsigned int *result, unsigned int *w) {
unsigned int a = result[0];
unsigned int b = result[1];
unsigned int c = result[2];
unsigned int d = result[3];
unsigned int e = result[4];
int round = 0;
int round = 0;
#define sha1macro(func,val) \
{ \
const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \
e = d; \
d = c; \
c = rol(b, 30); \
b = a; \
a = t; \
}
#define sha1macro(func, val) \
{ \
const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \
e = d; \
d = c; \
c = rol(b, 30); \
b = a; \
a = t; \
}
while (round < 16)
{
sha1macro((b & c) | (~b & d), 0x5a827999)
++round;
}
while (round < 20)
{
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (~b & d), 0x5a827999)
++round;
}
while (round < 40)
{
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0x6ed9eba1)
++round;
}
while (round < 60)
{
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)
++round;
}
while (round < 80)
{
w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0xca62c1d6)
++round;
}
while (round < 16) {
sha1macro((b & c) | (~b & d), 0x5a827999)++ round;
}
while (round < 20) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (~b & d), 0x5a827999)++ round;
}
while (round < 40) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0x6ed9eba1)++ round;
}
while (round < 60) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)++ round;
}
while (round < 80) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0xca62c1d6)++ round;
}
#undef sha1macro
#undef sha1macro
result[0] += a;
result[1] += b;
result[2] += c;
result[3] += d;
result[4] += e;
}
} // namespace
result[0] += a;
result[1] += b;
result[2] += c;
result[3] += d;
result[4] += e;
}
} // namespace
inline void calc(const void* src, const int bytelength, unsigned char* hash)
{
// Init the result array.
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
inline void calc(const void *src, const int bytelength, unsigned char *hash) {
// Init the result array.
unsigned int result[5] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476,
0xc3d2e1f0};
// Cast the void src pointer to be the byte array we can work with.
const unsigned char* sarray = (const unsigned char*) src;
// Cast the void src pointer to be the byte array we can work with.
const unsigned char *sarray = (const unsigned char *)src;
// The reusable round buffer
unsigned int w[80];
// The reusable round buffer
unsigned int w[80];
// Loop through all complete 64byte blocks.
const int endOfFullBlocks = bytelength - 64;
int endCurrentBlock;
int currentBlock = 0;
// Loop through all complete 64byte blocks.
const int endOfFullBlocks = bytelength - 64;
int endCurrentBlock;
int currentBlock = 0;
while (currentBlock <= endOfFullBlocks)
{
endCurrentBlock = currentBlock + 64;
while (currentBlock <= endOfFullBlocks) {
endCurrentBlock = currentBlock + 64;
// Init the round buffer with the 64 byte block data.
for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4)
{
// This line will swap endian on big endian and keep endian on little endian.
w[roundPos++] = (unsigned int) sarray[currentBlock + 3]
| (((unsigned int) sarray[currentBlock + 2]) << 8)
| (((unsigned int) sarray[currentBlock + 1]) << 16)
| (((unsigned int) sarray[currentBlock]) << 24);
}
innerHash(result, w);
}
// Handle the last and not full 64 byte block if existing.
endCurrentBlock = bytelength - currentBlock;
clearWBuffert(w);
int lastBlockBytes = 0;
for (;lastBlockBytes < endCurrentBlock; ++lastBlockBytes)
{
w[lastBlockBytes >> 2] |= (unsigned int) sarray[lastBlockBytes + currentBlock] << ((3 - (lastBlockBytes & 3)) << 3);
}
w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3);
if (endCurrentBlock >= 56)
{
innerHash(result, w);
clearWBuffert(w);
}
w[15] = bytelength << 3;
innerHash(result, w);
// Store hash in result pointer, and make sure we get in in the correct order on both endian models.
for (int hashByte = 20; --hashByte >= 0;)
{
hash[hashByte] = (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff;
}
// Init the round buffer with the 64 byte block data.
for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) {
// This line will swap endian on big endian and keep endian on little
// endian.
w[roundPos++] = (unsigned int)sarray[currentBlock + 3] |
(((unsigned int)sarray[currentBlock + 2]) << 8) |
(((unsigned int)sarray[currentBlock + 1]) << 16) |
(((unsigned int)sarray[currentBlock]) << 24);
}
innerHash(result, w);
}
inline void toHexString(const unsigned char* hash, char* hexstring)
{
const char hexDigits[] = { "0123456789abcdef" };
// Handle the last and not full 64 byte block if existing.
endCurrentBlock = bytelength - currentBlock;
clearWBuffert(w);
int lastBlockBytes = 0;
for (; lastBlockBytes < endCurrentBlock; ++lastBlockBytes) {
w[lastBlockBytes >> 2] |=
(unsigned int)sarray[lastBlockBytes + currentBlock]
<< ((3 - (lastBlockBytes & 3)) << 3);
}
w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3);
if (endCurrentBlock >= 56) {
innerHash(result, w);
clearWBuffert(w);
}
w[15] = bytelength << 3;
innerHash(result, w);
for (int hashByte = 20; --hashByte >= 0;)
{
hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf];
hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf];
}
hexstring[40] = 0;
}
// Store hash in result pointer, and make sure we get in in the correct order
// on both endian models.
for (int hashByte = 20; --hashByte >= 0;) {
hash[hashByte] =
(result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff;
}
}
inline void toHexString(const unsigned char *hash, char *hexstring) {
const char hexDigits[] = {"0123456789abcdef"};
for (int hashByte = 20; --hashByte >= 0;) {
hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf];
hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf];
}
hexstring[40] = 0;
}
} // namespace sha1
#endif

View File

@@ -7,11 +7,8 @@
#include <stdexcept>
#include <string>
namespace triton
{
namespace tools
{
namespace triton {
namespace tools {
#ifdef _WIN32
#define popen _popen
@@ -19,12 +16,12 @@ namespace tools
#endif
#ifndef WEXITSTATUS
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255)
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val)&255)
#endif
int exec(const std::string& cmd, std::string& result) {
int exec(const std::string &cmd, std::string &result) {
char buffer[128];
FILE* pipe = popen(cmd.c_str(), "r");
FILE *pipe = popen(cmd.c_str(), "r");
if (!pipe)
return 0;
result.clear();
@@ -37,10 +34,9 @@ int exec(const std::string& cmd, std::string& result) {
}
int status = pclose(pipe);
return WEXITSTATUS(status);
}
}
}
} // namespace tools
} // namespace triton
#endif

27
include/triton/tools/sys/getenv.hpp Executable file → Normal file
View File

@@ -22,26 +22,23 @@
#ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP
#include <string>
#include <cstdlib>
#include <string>
namespace triton
{
namespace triton {
namespace tools
{
inline std::string getenv(const char * name)
{
const char * cstr = std::getenv(name);
if(!cstr)
return "";
std::string result(cstr);
return result;
}
namespace tools {
inline std::string getenv(const char *name) {
const char *cstr = std::getenv(name);
if (!cstr)
return "";
std::string result(cstr);
return result;
}
}
} // namespace tools
} // namespace triton
#endif

74
include/triton/tools/sys/mkdir.hpp Executable file → Normal file
View File

@@ -22,55 +22,49 @@
#ifndef TDL_TOOLS_SYS_MKDIR_HPP
#define TDL_TOOLS_SYS_MKDIR_HPP
#include <cstring>
#include <string>
#include <cstdlib>
#include <sys/stat.h>
#include <cstring>
#include <errno.h>
#include <string>
#include <sys/stat.h>
#if defined(_WIN32)
#include <direct.h>
#include <direct.h>
#endif
namespace triton
{
namespace triton {
namespace tools
{
inline int mkdir(std::string const & path)
{
#if defined(_WIN32)
return _mkdir(path.c_str());
#else
return ::mkdir(path.c_str(), 0777);
#endif
}
inline int mkpath(std::string const & path)
{
int status = 0;
size_t pp = 0;
size_t sp;
while ((sp = path.find('/', pp)) != std::string::npos)
{
if (sp != pp){
status = mkdir(path.substr(0, sp));
}
pp = sp + 1;
}
return (status==0 || errno==EEXIST)?0:-1;
}
inline int mtime(std::string const & path)
{
struct stat st;
if(stat(path.c_str(), &st) != 0)
return 0;
return st.st_mtime;
}
namespace tools {
inline int mkdir(std::string const &path) {
#if defined(_WIN32)
return _mkdir(path.c_str());
#else
return ::mkdir(path.c_str(), 0777);
#endif
}
inline int mkpath(std::string const &path) {
int status = 0;
size_t pp = 0;
size_t sp;
while ((sp = path.find('/', pp)) != std::string::npos) {
if (sp != pp) {
status = mkdir(path.substr(0, sp));
}
pp = sp + 1;
}
return (status == 0 || errno == EEXIST) ? 0 : -1;
}
inline int mtime(std::string const &path) {
struct stat st;
if (stat(path.c_str(), &st) != 0)
return 0;
return st.st_mtime;
}
} // namespace tools
} // namespace triton
#endif

View File

@@ -3,88 +3,79 @@
#ifndef _TRITON_TOOLS_THREAD_POOL_H_
#define _TRITON_TOOLS_THREAD_POOL_H_
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
ThreadPool(size_t threads)
: stop(false) {
for(size_t i = 0;i < threads;++i)
workers.emplace_back(
[this] {
for(;;){
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
template <class F, class... Args>
auto enqueue(F &&f, Args &&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
using return_type = typename std::result_of<F(Args...)>::type;
std::unique_lock<std::mutex> lock(queue_mutex);
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
// don't allow enqueueing after stopping the pool
if (stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread &worker : workers)
worker.join();
}
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// need to keep track of threads so we can join them
std::vector<std::thread> workers;
// the task queue
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
#endif

View File

@@ -8,24 +8,23 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
// Function for extended Euclidean Algorithm
static int gcd_impl(int a, int b, int *x, int *y){
// Function for extended Euclidean Algorithm
static int gcd_impl(int a, int b, int *x, int *y) {
// Base Case
if (a == 0) {
*x = 0;
*y = 1;
return b;
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
int gcd = gcd_impl(b % a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*x = y1 - (b / a) * x1;
*y = x1;
return gcd;
}
@@ -35,17 +34,17 @@ static int gcd(int a, int b) {
return gcd_impl(a, b, &x, &y);
}
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
size_t rank = 1;
if(TensorType ty = value.getType().dyn_cast<TensorType>())
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
Operation* op = blockArg.getOwner()->getParentOp();
if(FuncOp fun = dyn_cast<FuncOp>(op)){
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if(attr)
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
@@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
return AxisInfo(contiguity, divisibility, constancy);
}
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs,
const AxisInfo &rhs) {
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
ContiguityT retContiguity;
DivisibilityT retDivisibility;
ConstancyT retConstancy;
for(size_t d = 0; d < lhs.getRank(); d++){
for (size_t d = 0; d < lhs.getRank(); d++) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retDivisibility.push_back(
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
return AxisInfo(retContiguity, retDivisibility, retConstancy);
}
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::ContiguityT newContiguity;
AxisInfo::DivisibilityT newDivisibility;
AxisInfo::ConstancyT newConstancy;
for(size_t d = 0; d < rank; d++){
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
}
return AxisInfo(newContiguity, newDivisibility, newConstancy);
AxisInfo AxisInfoAnalysis::visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::ContiguityT newContiguity;
AxisInfo::DivisibilityT newDivisibility;
AxisInfo::ConstancyT newConstancy;
for (size_t d = 0; d < rank; d++) {
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
}
return AxisInfo(newContiguity, newDivisibility, newConstancy);
}
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) {
ChangeResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
triton::PtrToIntOp, triton::IntToPtrOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
if (triton::MakeRangeOp make_range =
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
int start = make_range.start();
int end = make_range.end();
AxisInfo::ContiguityT contiguity = {end - start};
@@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Constant
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
if(intAttr){
if (intAttr) {
size_t val = intAttr.getValue().getZExtValue();
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
}
// TODO: generalize to dense attr
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
if(splatAttr && splatAttr.getElementType().isInteger(32)){
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
auto value = splatAttr.getSplatValue<int>();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
curr = AxisInfo(
AxisInfo::ContiguityT(ty.getRank(), 1),
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
}
}
// Addition
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
};
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
newContiguity, newDivisibility, newConstancy);
}
// Multiplication
if (llvm::isa<arith::MulIOp>(op)){
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
return 1;
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
if (llvm::isa<arith::MulIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
newContiguity, newDivisibility, newConstancy);
}
// Splat
if (llvm::isa<triton::SplatOp>(op)){
if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for(size_t d = 0; d < retTy.getRank(); d++){
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
@@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
}
// Reshape
// TODO: Replace by `unsqueeze`
if (llvm::isa<triton::ReshapeOp>(op)){
if (llvm::isa<triton::ReshapeOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
AxisInfo::ConstancyT constancy;
bool is_skewed = false;
size_t current = 0;
for(size_t d = 0; d < retTy.getRank(); d++){
if(retShape[d] == 1){
for (size_t d = 0; d < retTy.getRank(); d++) {
if (retShape[d] == 1) {
contiguity.push_back(1);
divisibility.push_back(1);
constancy.push_back(1);
}
else if(!is_skewed
&& retShape[d] == opShape[current]){
} else if (!is_skewed && retShape[d] == opShape[current]) {
contiguity.push_back(opInfo.getContiguity()[current]);
divisibility.push_back(opInfo.getDivisibility()[current]);
constancy.push_back(opInfo.getConstancy()[current]);
current++;
}
else {
} else {
is_skewed = true;
contiguity.push_back(1);
divisibility.push_back(1);
@@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Broadcast
if (llvm::isa<triton::BroadcastOp>(op)){
if (llvm::isa<triton::BroadcastOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for(size_t d = 0; d < retTy.getRank(); d++){
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
if(curr.getRank() == 0){
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
// join all latice elements
@@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
return result;
}
}
} // namespace mlir

View File

@@ -2,14 +2,16 @@
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir{
namespace triton{
namespace mlir {
namespace triton {
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
}
}
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,42 +1,42 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
using namespace mlir;
using namespace mlir::triton;
namespace {
template<class Op>
class ArithGenericPattern : public OpConversionPattern<Op> {
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
Op res =
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
template<class SrcOp, class DstOp>
template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
using OpConversionPattern<SrcOp>::OpConversionPattern;
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
);
DstOp res =
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
op, retType,
value.reshape(retType) // This is a hack. We just want to add encoding
);
return success();
}
};
class ConvertArithmeticOp: public ConversionPattern {
class ConvertArithmeticOp : public ConversionPattern {
public:
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override {
Dialect* dialect = op->getDialect();
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
return failure();
return success();
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Dialect *dialect = op->getDialect();
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
return failure();
return success();
}
};
void populateArithmeticPatternsAndLegality(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target){
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
// --------------
// Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that
@@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality(
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithConstantPattern,
ArithGenericPattern<arith::AddIOp>,
ArithGenericPattern<arith::SubIOp>,
ArithGenericPattern<arith::MulIOp>,
ArithGenericPattern<arith::DivUIOp>,
ArithGenericPattern<arith::DivSIOp>,
ArithGenericPattern<arith::CeilDivUIOp>,
ArithGenericPattern<arith::CeilDivSIOp>,
ArithGenericPattern<arith::FloorDivSIOp>,
ArithGenericPattern<arith::RemUIOp>,
ArithGenericPattern<arith::RemSIOp>,
ArithGenericPattern<arith::AndIOp>,
ArithGenericPattern<arith::OrIOp>,
ArithGenericPattern<arith::XOrIOp>,
ArithGenericPattern<arith::ShLIOp>,
ArithGenericPattern<arith::ShRUIOp>,
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
// Floating point
ArithGenericPattern<arith::AddFOp>,
ArithGenericPattern<arith::SubFOp>,
// MaxMin
ArithGenericPattern<arith::MaxFOp>,
ArithGenericPattern<arith::MaxSIOp>,
ArithGenericPattern<arith::MaxUIOp>,
ArithGenericPattern<arith::MinFOp>,
ArithGenericPattern<arith::MinSIOp>,
ArithGenericPattern<arith::MinUIOp>,
// Floating point
ArithGenericPattern<arith::MulFOp>,
ArithGenericPattern<arith::DivFOp>,
ArithGenericPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>
>(typeConverter, context);
patterns.add<
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
ArithGenericPattern<arith::CeilDivUIOp>,
ArithGenericPattern<arith::CeilDivSIOp>,
ArithGenericPattern<arith::FloorDivSIOp>,
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
ArithGenericPattern<arith::ShRUIOp>,
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
// Floating point
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
// MaxMin
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
// Floating point
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
ArithGenericPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
}
//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
struct TritonMakeRangePattern
: public OpConversionPattern<triton::MakeRangeOp> {
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, retType, adaptor.start(), adaptor.end()
);
op, retType, adaptor.start(), adaptor.end());
return success();
}
};
@@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
// a & b must be of smem layout
auto aType = adaptor.a().getType().cast<RankedTensorType>();
@@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.b();
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
);
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
return success();
}
};
@@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, retType,
adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
);
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
return success();
}
};
@@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
);
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
return success();
}
};
@@ -212,12 +209,11 @@ template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
@@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
);
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
return success();
}
};
void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::SplatOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>,
TritonReducePattern,
TritonMakeRangePattern,
TritonDotPattern,
TritonLoadPattern,
TritonStorePattern
>(typeConverter, context);
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
}
//
@@ -259,17 +250,19 @@ void populateTritonPatterns(
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
// Ref: ConvertForOpTypes
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
LogicalResult
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp =
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());
// Now, update all the types.
// Convert the types of block arguments within the given region. This
// replaces each block with a new block containing the updated signature. The
// entry block may have a special conversion if `entryConversion` is
// replaces each block with a new block containing the updated signature.
// The entry block may have a special conversion if `entryConversion` is
// provided. On success, the new entry block to the region is returned for
// convenience. Otherwise, failure is returned.
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
@@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
// op.erase();
rewriter.replaceOpWithNewOp<scf::YieldOp>(
op, adaptor.getOperands()
);
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
return success();
}
};
void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern
>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
}
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
class ConvertTritonToTritonGPU
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
ConvertTritonToTritonGPU(int numWarps) {
this->numWarps = numWarps;
}
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -339,21 +326,21 @@ public:
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
// TODO: can we use
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// update layouts
// broadcast src => multicast, dst => broadcasted
if(failed(target.refineLayouts(mod, numWarps)))
if (failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
}
};
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {

View File

@@ -7,7 +7,6 @@
#include "mlir/IR/DialectImplementation.h"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
using namespace mlir;
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
>();
>();
// We can also add interface here.
}
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
}

View File

@@ -13,14 +13,16 @@ namespace triton {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
return RankedTensorType::get(tensorType.getShape(), i32Type,
tensorType.getEncoding());
return Type();
}
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
return Type();
}
}
}
} // namespace triton
} // namespace mlir
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
@@ -48,50 +50,48 @@ namespace triton {
//-- StoreOp --
// Default mask
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
auto shape = ptrType.getShape();
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(),
RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true
)
);
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
state.addOperands(ptr);
state.addOperands(value);
state.addOperands(mask);
}
//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
// mask
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(),
RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true
)
);
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
// other
Type resultType = RankedTensorType::get(shape, elementType);
::mlir::Value other = builder.create<arith::ConstantOp>(
ptr.getLoc(),
resultType,
DenseElementsAttr::get(
resultType, builder.getZeroAttr(elementType)
)
);
ptr.getLoc(), resultType,
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}

View File

@@ -1,6 +1,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir;
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
>();
>();
}
Type PointerType::parse(AsmParser &parser) {

View File

@@ -17,21 +17,23 @@ namespace {
class CombineDotOp : public mlir::RewritePattern {
public:
CombineDotOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
if (isCandidate(op->getOperand(0)).succeeded()) {
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
op, dotOp->getResultTypes().front(), dotOp.a(),
dotOp.b(), op->getOperand(1), dotOp.allowTF32());
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
op->getOperand(1), dotOp.allowTF32());
return mlir::success();
} else if (isCandidate(op->getOperand(1)).succeeded()) {
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
op, dotOp->getResultTypes().front(), dotOp.a(),
dotOp.b(), op->getOperand(0), dotOp.allowTF32());
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
op->getOperand(0), dotOp.allowTF32());
return mlir::success();
}
}
@@ -54,7 +56,7 @@ private:
return true;
// broadcast(constant_0)
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
return true;
}
@@ -68,18 +70,19 @@ private:
class CombineGEPOp : public mlir::RewritePattern {
public:
CombineGEPOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::triton::GEPOp>(op)) {
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
auto loc = op->getLoc();
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
loc, op->getOperand(1), gep2->getOperand(1));
loc, op->getOperand(1), gep2->getOperand(1));
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
);
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
return mlir::success();
}
}
@@ -92,20 +95,21 @@ public:
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
public:
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::SelectOp>(op)) {
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
mlir::Value cond = op->getOperand(0);
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
op, op->getResultTypes().front(),
load.ptr(), load.mask(), op->getOperand(2),
load.cache(), load.evict(), load.isVolatile()
);
op, op->getResultTypes().front(), load.ptr(), load.mask(),
op->getOperand(2), load.cache(), load.evict(),
load.isVolatile());
return mlir::success();
}
}
@@ -120,11 +124,11 @@ public:
class CombineBroadcastConstantOp : public mlir::RewritePattern {
public:
CombineBroadcastConstantOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
PatternRewriter &rewriter) const override {
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
Attribute value = cst.getValue();
@@ -132,15 +136,14 @@ public:
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
if (!denseValue.isSplat())
return failure();
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
value = DenseElementsAttr::get(resType,
denseValue.getSplatValue<Attribute>());
} else {
if (!value.isa<FloatAttr, IntegerAttr>())
return failure();
value = DenseElementsAttr::get(resType, value);
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, value, resType
);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
return success();
}
}

View File

@@ -11,19 +11,18 @@ using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
/*SmallVector<unsigned, 2>*/auto &res,
StringRef desc) {
/*SmallVector<unsigned, 2>*/ auto &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), "expected an array for ")
<< desc;
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
return failure();
}
for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
<< desc;
return failure();
}
res.push_back(intAttr.getUInt());
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
.failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
.failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
.failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else if (attr.getName() == "broadcastAxis") {
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
}
}
return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
threadTileSize,
warpTileSize,
blockTileSize,
order,
broadcastAxis);
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
}
static void printBlocked(AsmPrinter &printer, auto *attr) {
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
<< "}>";
}
Attribute
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
}
Attribute
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
parseBlocked(parser, type);
}
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
.failed())
return {};
} else if (attr.getName() == "shapePerWarp") {
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
.failed())
return {};
} else if (attr.getName() == "warpPerTile") {
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
return {};
} else if (attr.getName() == "shapePerTile") {
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
.failed())
return {};
} else if (attr.getName() == "repetitions") {
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
return {};
} else if (attr.getName() == "contigPerThread") {
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
fragmentPerWarp,
shapePerWarp,
warpPerTile,
shapePerTile,
repetitions,
contigPerThread,
broadcastAxis);
return parser.getChecked<TritonGPUMmaEncodingAttr>(
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
shapePerTile, repetitions, contigPerThread, broadcastAxis);
}
static void printMma(AsmPrinter &printer, auto *attr) {
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
<< "}>";
}
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
}
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
Attribute
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseMma(parser, type);
}
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
Attribute
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr,
unsigned &value,
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
vec,
perPhase,
maxPhase,
order);
return parser.getChecked<TritonGPUSharedEncodingAttr>(
parser.getContext(), vec, perPhase, maxPhase, order);
}
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "vec = " << getVec()
<< ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase()
<< ", order = [" << getOrder() << "]"
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
<< "]"
<< "}>";
}
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
OpAsmDialectInterface::getAlias(attr, os);
}
private:
private:
static void printMma(const auto &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
>();
addInterfaces<TritonGPUOpAsmInterface>();
}
@@ -349,7 +340,8 @@ namespace triton {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
return Type();
}
}
}
} // namespace triton
} // namespace mlir
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
// verify TritonGPU ops
LogicalResult
TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}

View File

@@ -27,8 +27,8 @@ namespace {
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();

View File

@@ -6,12 +6,11 @@
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
#define GEN_PASS_CLASSES
@@ -41,14 +40,15 @@ class LoopPipeliner {
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation*> depOps;
DenseSet<Operation *> depOps;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depends on value in previous iterations
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
}
// for (triton::LoadOp loadOp : allLoads) {
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
// for (Value dep : loadDeps[loadOp])
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
// " values\n"; for (Value dep : loadDeps[loadOp])
// llvm::errs() << dep << "\n";
// llvm::errs() << "\n";
// }
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding()
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
loads.insert(loadOp);
}
// we have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
SmallVector<Operation*> orderedDeps;
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[loadOp].getType(),
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
loadOp.isVolatile());
} else
llvm_unreachable("This should be LoadOp");
} else
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
loopCond);
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
mask,
splatCond);
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), loopCond);
Value newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
newOp->setOperand(1, newMask);
}
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
}
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(),
newLoopArgs);
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newLoopArgs);
// 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody());
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
}
// 4. prefetch the next iteration
SmallVector<Operation*> orderedDeps;
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
nextMapping.map(arg,
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx],
newForOp.getStep());
Value nextLoopCond = builder.create<arith::CmpIOp>(
nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
splatCond,
nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than once
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
}
else
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());
} else
nextOp = builder.clone(*op, nextMapping);
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
// update mapping of results
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
for (int stage = 1; stage < numStages - 1; ++stage) {
yieldValues.push_back(newForOp.getRegionIterArgs()[
loadIdx + idx*(numStages-1) + stage
]);
yieldValues.push_back(
newForOp
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
}
yieldValues.push_back(nextMapping.lookup(load));
}
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) {
this->numStages = numStages;
}
PipelinePass(int numStages) { this->numStages = numStages; }
void runOnOperation() override {
int numStages = this->numStages;

View File

@@ -1,7 +1,7 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <algorithm>
using namespace mlir;
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
//
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numThreads)
: context(context), numThreads(numThreads) {
// TODO: how does MLIR pick the right conversion?
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// or assert no encoding?
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32;
int remainingLanes = /*warp size*/ 32;
for (int64_t dim = 0; dim < rank; ++dim) {
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
context, threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding);
});
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
StandardOpsDialect, scf::SCFDialect>(
[&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding &&
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// return true;
return false;
});
}
// %dst = tt.broadcast %src
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) {
broadcasts.push_back(op);
});
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
if (auto blockedEnc =
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(),
blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(),
blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(),
broadcastAxis
);
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(), broadcastAxis);
newSrcType = RankedTensorType::get(
tensorType.getShape(),
tensorType.getElementType(),
newSrcEnc
);
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(
src.getLoc(), newSrcType, src
);
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
newSrcType, src);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(),
originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(),
originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(),
broadcastAxis
);
auto newType = RankedTensorType::get(
originDstTensorType.getShape(),
originDstTensorType.getElementType(),
newEnc
);
Value newBroadcast = builder.create<triton::BroadcastOp>(
broadcast.getLoc(), newType, src
);
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(), broadcastAxis);
auto newType =
RankedTensorType::get(originDstTensorType.getShape(),
originDstTensorType.getElementType(), newEnc);
Value newBroadcast =
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast
);
broadcast.getLoc(), originDstType, newBroadcast);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);

View File

@@ -5,7 +5,6 @@
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -37,28 +36,30 @@ private:
if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return dotOp.emitError() << name << " should be of shared layout";
} else
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
Attribute cLayout;
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
llvm::SmallVector<char>{'c', 'd'})) {
llvm::SmallVector<char>{'c', 'd'})) {
Type type = std::get<0>(it);
char name = std::get<1>(it);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout";
return dotOp.emitError()
<< name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;
else if (encoding != cLayout)
return dotOp.emitError() << "d & c should have the same layout";
} else
return dotOp.emitError() << name
<< "'s type should be of RankedTensorType";
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
// signalPassFailure();
@@ -89,7 +90,7 @@ private:
}
void verifyImpl(Operation *op) {
if(verifySingleOp(op).failed())
if (verifySingleOp(op).failed())
signalPassFailure();
// verify that all child regions are ok

408
lib/driver/dispatch.cc Executable file → Normal file
View File

@@ -1,107 +1,152 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace triton {
namespace driver {
//Helpers for function definition
#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname); }\
void* dispatch::fname ## _;
// Helpers for function definition
#define DEFINE0(init, hlib, ret, fname) \
ret dispatch::fname() { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname); \
} \
void *dispatch::fname##_;
#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a); }\
void* dispatch::fname ## _;
#define DEFINE1(init, hlib, ret, fname, t1) \
ret dispatch::fname(t1 a) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a); \
} \
void *dispatch::fname##_;
#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b); }\
void* dispatch::fname ## _;
#define DEFINE2(init, hlib, ret, fname, t1, t2) \
ret dispatch::fname(t1 a, t2 b) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b); \
} \
void *dispatch::fname##_;
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c); }\
void* dispatch::fname ## _;
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \
ret dispatch::fname(t1 a, t2 b, t3 c) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c); \
} \
void *dispatch::fname##_;
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d); }\
void* dispatch::fname ## _;
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d); \
} \
void *dispatch::fname##_;
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\
void* dispatch::fname ## _;
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e); \
} \
void *dispatch::fname##_;
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\
void* dispatch::fname ## _;
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f); \
} \
void *dispatch::fname##_;
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\
void* dispatch::fname ## _;
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g); \
} \
void *dispatch::fname##_;
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\
void* dispatch::fname ## _;
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h); \
} \
void *dispatch::fname##_;
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\
void* dispatch::fname ## _;
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i); \
} \
void *dispatch::fname##_;
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\
void* dispatch::fname ## _;
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j); \
} \
void *dispatch::fname##_;
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\
void* dispatch::fname ## _;
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k); \
} \
void *dispatch::fname##_;
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\
void* dispatch::fname ## _;
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\
void* dispatch::fname ## _;
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m); \
} \
void *dispatch::fname##_;
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \
t18 r, t19 s) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m, n, o, p, q, r, \
s); \
} \
void *dispatch::fname##_;
/* ------------------- *
* CUDA
* ------------------- */
bool dispatch::cuinit(){
if(cuda_==nullptr){
#ifdef _WIN32
bool dispatch::cuinit() {
if (cuda_ == nullptr) {
#ifdef _WIN32
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
#else
#else
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
if(!cuda_)
if (!cuda_)
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
#endif
if(!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
#endif
if (!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is "
"in your LD_LIBRARY_PATH.");
}
if(cuda_ == nullptr)
if (cuda_ == nullptr)
return false;
CUresult (*fptr)(unsigned int);
cuInit_ = dlsym(cuda_, "cuInit");
@@ -112,21 +157,33 @@ bool dispatch::cuinit(){
}
#define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
#define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
#define CUDA_DEFINE2(ret, fname, t1, t2) \
DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11) \
DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
// context management
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *)
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
CUDA_DEFINE1(CUresult, cuInit, unsigned int)
CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
@@ -134,59 +191,71 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute,
CUdevice)
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
// link management
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *,
size_t, const char *, unsigned int, CUjit_option *, void **);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **,
CUlinkState *);
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *);
// module management
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule,
const char *)
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *,
unsigned int, CUjit_option *, void **)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule,
const char *)
// stream management
CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *)
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void **, void **)
// function management
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute,
CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute,
int)
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
// memory management
CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t )
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t,
CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t,
CUstream)
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t)
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute,
CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t,
CUstream)
// event management
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
/* ------------------- *
* NVML
* ------------------- */
bool dispatch::nvmlinit(){
#ifdef _WIN32
if(nvml_==nullptr)
bool dispatch::nvmlinit() {
#ifdef _WIN32
if (nvml_ == nullptr)
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
#else
if(nvml_==nullptr)
#else
if (nvml_ == nullptr)
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
#endif
#endif
nvmlReturn_t (*fptr)();
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
@@ -197,21 +266,27 @@ bool dispatch::nvmlinit(){
#define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname)
#define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1)
#define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
#define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
#define NVML_DEFINE2(ret, fname, t1, t2) \
DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
#define NVML_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *,
nvmlDevice_t *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t,
unsigned int, unsigned int)
/* ------------------- *
* HIP
* ------------------- */
bool dispatch::hipinit(){
if(hip_==nullptr)
bool dispatch::hipinit() {
if (hip_ == nullptr)
hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
if(hip_ == nullptr)
if (hip_ == nullptr)
return false;
hipError_t (*fptr)();
hipInit_ = dlsym(hip_, "hipInit");
@@ -222,23 +297,34 @@ bool dispatch::hipinit(){
}
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
#define HIP_DEFINE2(ret, fname, t1, t2) \
DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
// context management
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *)
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *)
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
@@ -246,56 +332,64 @@ HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t,
hipDevice_t)
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
// module management
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*)
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *,
hipModule_t, const char *)
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *,
unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t,
const char *)
// stream management
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, hipStream_t, void **, void **)
// function management
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*)
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *)
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
// memory management
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t )
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t)
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t,
hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *,
size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t)
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t)
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute,
hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t,
hipStream_t)
// event management
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
/* ------------------- *
* COMMON
* ------------------- */
// Release
void dispatch::release(){
if(cuda_){
void dispatch::release() {
if (cuda_) {
dlclose(cuda_);
cuda_ = nullptr;
}
}
void* dispatch::cuda_;
void* dispatch::nvml_;
void* dispatch::nvmlInit_v2_;
void* dispatch::hip_;
void *dispatch::cuda_;
void *dispatch::nvml_;
void *dispatch::nvmlInit_v2_;
void *dispatch::hip_;
}
}
} // namespace driver
} // namespace triton

410
lib/driver/error.cc Executable file → Normal file
View File

@@ -1,166 +1,270 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/error.h"
namespace triton
{
namespace driver
{
namespace triton {
namespace driver {
void check(CUresult err)
{
void check(CUresult err) {
using namespace exception::cuda;
switch(err)
{
case CUDA_SUCCESS : break;
case CUDA_ERROR_INVALID_VALUE : throw invalid_value();
case CUDA_ERROR_OUT_OF_MEMORY : throw out_of_memory();
case CUDA_ERROR_NOT_INITIALIZED : throw not_initialized();
case CUDA_ERROR_DEINITIALIZED : throw deinitialized();
case CUDA_ERROR_PROFILER_DISABLED : throw profiler_disabled();
case CUDA_ERROR_PROFILER_NOT_INITIALIZED : throw profiler_not_initialized();
case CUDA_ERROR_PROFILER_ALREADY_STARTED : throw profiler_already_started();
case CUDA_ERROR_PROFILER_ALREADY_STOPPED : throw profiler_already_stopped();
case CUDA_ERROR_NO_DEVICE : throw no_device();
case CUDA_ERROR_INVALID_DEVICE : throw invalid_device();
case CUDA_ERROR_INVALID_IMAGE : throw invalid_image();
case CUDA_ERROR_INVALID_CONTEXT : throw invalid_context();
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT : throw context_already_current();
case CUDA_ERROR_MAP_FAILED : throw map_failed();
case CUDA_ERROR_UNMAP_FAILED : throw unmap_failed();
case CUDA_ERROR_ARRAY_IS_MAPPED : throw array_is_mapped();
case CUDA_ERROR_ALREADY_MAPPED : throw already_mapped();
case CUDA_ERROR_NO_BINARY_FOR_GPU : throw no_binary_for_gpu();
case CUDA_ERROR_ALREADY_ACQUIRED : throw already_acquired();
case CUDA_ERROR_NOT_MAPPED : throw not_mapped();
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY : throw not_mapped_as_array();
case CUDA_ERROR_NOT_MAPPED_AS_POINTER : throw not_mapped_as_pointer();
case CUDA_ERROR_ECC_UNCORRECTABLE : throw ecc_uncorrectable();
case CUDA_ERROR_UNSUPPORTED_LIMIT : throw unsupported_limit();
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE : throw context_already_in_use();
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED : throw peer_access_unsupported();
case CUDA_ERROR_INVALID_PTX : throw invalid_ptx();
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT : throw invalid_graphics_context();
case CUDA_ERROR_INVALID_SOURCE : throw invalid_source();
case CUDA_ERROR_FILE_NOT_FOUND : throw file_not_found();
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND : throw shared_object_symbol_not_found();
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED : throw shared_object_init_failed();
case CUDA_ERROR_OPERATING_SYSTEM : throw operating_system();
case CUDA_ERROR_INVALID_HANDLE : throw invalid_handle();
case CUDA_ERROR_NOT_FOUND : throw not_found();
case CUDA_ERROR_NOT_READY : throw not_ready();
case CUDA_ERROR_ILLEGAL_ADDRESS : throw illegal_address();
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES : throw launch_out_of_resources();
case CUDA_ERROR_LAUNCH_TIMEOUT : throw launch_timeout();
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING : throw launch_incompatible_texturing();
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED : throw peer_access_already_enabled();
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED : throw peer_access_not_enabled();
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE : throw primary_context_active();
case CUDA_ERROR_CONTEXT_IS_DESTROYED : throw context_is_destroyed();
case CUDA_ERROR_ASSERT : throw assert_error();
case CUDA_ERROR_TOO_MANY_PEERS : throw too_many_peers();
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED : throw host_memory_already_registered();
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED : throw host_memory_not_registered();
case CUDA_ERROR_HARDWARE_STACK_ERROR : throw hardware_stack_error();
case CUDA_ERROR_ILLEGAL_INSTRUCTION : throw illegal_instruction();
case CUDA_ERROR_MISALIGNED_ADDRESS : throw misaligned_address();
case CUDA_ERROR_INVALID_ADDRESS_SPACE : throw invalid_address_space();
case CUDA_ERROR_INVALID_PC : throw invalid_pc();
case CUDA_ERROR_LAUNCH_FAILED : throw launch_failed();
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
case CUDA_ERROR_UNKNOWN : throw unknown();
default : throw unknown();
switch (err) {
case CUDA_SUCCESS:
break;
case CUDA_ERROR_INVALID_VALUE:
throw invalid_value();
case CUDA_ERROR_OUT_OF_MEMORY:
throw out_of_memory();
case CUDA_ERROR_NOT_INITIALIZED:
throw not_initialized();
case CUDA_ERROR_DEINITIALIZED:
throw deinitialized();
case CUDA_ERROR_PROFILER_DISABLED:
throw profiler_disabled();
case CUDA_ERROR_PROFILER_NOT_INITIALIZED:
throw profiler_not_initialized();
case CUDA_ERROR_PROFILER_ALREADY_STARTED:
throw profiler_already_started();
case CUDA_ERROR_PROFILER_ALREADY_STOPPED:
throw profiler_already_stopped();
case CUDA_ERROR_NO_DEVICE:
throw no_device();
case CUDA_ERROR_INVALID_DEVICE:
throw invalid_device();
case CUDA_ERROR_INVALID_IMAGE:
throw invalid_image();
case CUDA_ERROR_INVALID_CONTEXT:
throw invalid_context();
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT:
throw context_already_current();
case CUDA_ERROR_MAP_FAILED:
throw map_failed();
case CUDA_ERROR_UNMAP_FAILED:
throw unmap_failed();
case CUDA_ERROR_ARRAY_IS_MAPPED:
throw array_is_mapped();
case CUDA_ERROR_ALREADY_MAPPED:
throw already_mapped();
case CUDA_ERROR_NO_BINARY_FOR_GPU:
throw no_binary_for_gpu();
case CUDA_ERROR_ALREADY_ACQUIRED:
throw already_acquired();
case CUDA_ERROR_NOT_MAPPED:
throw not_mapped();
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY:
throw not_mapped_as_array();
case CUDA_ERROR_NOT_MAPPED_AS_POINTER:
throw not_mapped_as_pointer();
case CUDA_ERROR_ECC_UNCORRECTABLE:
throw ecc_uncorrectable();
case CUDA_ERROR_UNSUPPORTED_LIMIT:
throw unsupported_limit();
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE:
throw context_already_in_use();
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED:
throw peer_access_unsupported();
case CUDA_ERROR_INVALID_PTX:
throw invalid_ptx();
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT:
throw invalid_graphics_context();
case CUDA_ERROR_INVALID_SOURCE:
throw invalid_source();
case CUDA_ERROR_FILE_NOT_FOUND:
throw file_not_found();
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND:
throw shared_object_symbol_not_found();
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
throw shared_object_init_failed();
case CUDA_ERROR_OPERATING_SYSTEM:
throw operating_system();
case CUDA_ERROR_INVALID_HANDLE:
throw invalid_handle();
case CUDA_ERROR_NOT_FOUND:
throw not_found();
case CUDA_ERROR_NOT_READY:
throw not_ready();
case CUDA_ERROR_ILLEGAL_ADDRESS:
throw illegal_address();
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
throw launch_out_of_resources();
case CUDA_ERROR_LAUNCH_TIMEOUT:
throw launch_timeout();
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING:
throw launch_incompatible_texturing();
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED:
throw peer_access_already_enabled();
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED:
throw peer_access_not_enabled();
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE:
throw primary_context_active();
case CUDA_ERROR_CONTEXT_IS_DESTROYED:
throw context_is_destroyed();
case CUDA_ERROR_ASSERT:
throw assert_error();
case CUDA_ERROR_TOO_MANY_PEERS:
throw too_many_peers();
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED:
throw host_memory_already_registered();
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED:
throw host_memory_not_registered();
case CUDA_ERROR_HARDWARE_STACK_ERROR:
throw hardware_stack_error();
case CUDA_ERROR_ILLEGAL_INSTRUCTION:
throw illegal_instruction();
case CUDA_ERROR_MISALIGNED_ADDRESS:
throw misaligned_address();
case CUDA_ERROR_INVALID_ADDRESS_SPACE:
throw invalid_address_space();
case CUDA_ERROR_INVALID_PC:
throw invalid_pc();
case CUDA_ERROR_LAUNCH_FAILED:
throw launch_failed();
case CUDA_ERROR_NOT_PERMITTED:
throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED:
throw not_supported();
case CUDA_ERROR_UNKNOWN:
throw unknown();
default:
throw unknown();
}
}
void check(hipError_t error) {
using namespace exception::hip;
switch(error)
{
case hipSuccess : break;
case hipErrorInvalidValue : throw invalid_value();
case hipErrorMemoryAllocation : throw out_of_memory();
case hipErrorNotInitialized : throw not_initialized();
case hipErrorDeinitialized : throw deinitialized();
case hipErrorProfilerDisabled : throw profiler_disabled();
case hipErrorProfilerNotInitialized : throw profiler_not_initialized();
case hipErrorProfilerAlreadyStarted : throw profiler_already_started();
case hipErrorProfilerAlreadyStopped : throw profiler_already_stopped();
case hipErrorNoDevice : throw no_device();
case hipErrorInvalidSymbol : throw invalid_symbol();
case hipErrorInvalidDevice : throw invalid_device();
case hipErrorInvalidImage : throw invalid_image();
case hipErrorInvalidContext : throw invalid_context();
case hipErrorContextAlreadyCurrent : throw context_already_current();
case hipErrorMapFailed : throw map_failed();
case hipErrorUnmapFailed : throw unmap_failed();
case hipErrorArrayIsMapped : throw array_is_mapped();
case hipErrorAlreadyMapped : throw already_mapped();
case hipErrorNoBinaryForGpu : throw no_binary_for_gpu();
case hipErrorAlreadyAcquired : throw already_acquired();
case hipErrorNotMapped : throw not_mapped();
case hipErrorNotMappedAsArray : throw not_mapped_as_array();
case hipErrorNotMappedAsPointer : throw not_mapped_as_pointer();
case hipErrorECCNotCorrectable : throw ecc_uncorrectable();
case hipErrorUnsupportedLimit : throw unsupported_limit();
case hipErrorContextAlreadyInUse : throw context_already_in_use();
case hipErrorPeerAccessUnsupported : throw peer_access_unsupported();
case hipErrorInvalidKernelFile : throw invalid_ptx();
case hipErrorInvalidGraphicsContext : throw invalid_graphics_context();
case hipErrorInvalidSource : throw invalid_source();
case hipErrorFileNotFound : throw file_not_found();
case hipErrorSharedObjectSymbolNotFound : throw shared_object_symbol_not_found();
case hipErrorSharedObjectInitFailed : throw shared_object_init_failed();
case hipErrorOperatingSystem : throw operating_system();
case hipErrorInvalidResourceHandle : throw invalid_handle();
case hipErrorNotFound : throw not_found();
case hipErrorNotReady : throw not_ready();
case hipErrorIllegalAddress : throw illegal_address();
case hipErrorLaunchOutOfResources : throw launch_out_of_resources();
case hipErrorLaunchTimeOut : throw launch_timeout();
// case hipErrorLaunchIncompatibleTexturing : throw launch_incompatible_texturing();
case hipErrorPeerAccessAlreadyEnabled : throw peer_access_already_enabled();
case hipErrorPeerAccessNotEnabled : throw peer_access_not_enabled();
// case hipErrorPrimaryContextActive : throw primary_context_active();
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
case hipErrorAssert : throw assert_error();
// case hipErrorTooManyPeers : throw too_many_peers();
case hipErrorHostMemoryAlreadyRegistered : throw host_memory_already_registered();
case hipErrorHostMemoryNotRegistered : throw host_memory_not_registered();
// case hipErrorHardwareStackError : throw hardware_stack_error();
// case hipErrorIllegalInstruction : throw illegal_instruction();
// case hipErrorMisalignedAddress : throw misaligned_address();
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
// case hipErrorInvalidPc : throw invalid_pc();
case hipErrorLaunchFailure : throw launch_failed();
// case hipErrorNotPermitted : throw not_permitted();
case hipErrorNotSupported : throw not_supported();
case hipErrorUnknown : throw unknown();
default : throw unknown();
}
}
}
switch (error) {
case hipSuccess:
break;
case hipErrorInvalidValue:
throw invalid_value();
case hipErrorMemoryAllocation:
throw out_of_memory();
case hipErrorNotInitialized:
throw not_initialized();
case hipErrorDeinitialized:
throw deinitialized();
case hipErrorProfilerDisabled:
throw profiler_disabled();
case hipErrorProfilerNotInitialized:
throw profiler_not_initialized();
case hipErrorProfilerAlreadyStarted:
throw profiler_already_started();
case hipErrorProfilerAlreadyStopped:
throw profiler_already_stopped();
case hipErrorNoDevice:
throw no_device();
case hipErrorInvalidSymbol:
throw invalid_symbol();
case hipErrorInvalidDevice:
throw invalid_device();
case hipErrorInvalidImage:
throw invalid_image();
case hipErrorInvalidContext:
throw invalid_context();
case hipErrorContextAlreadyCurrent:
throw context_already_current();
case hipErrorMapFailed:
throw map_failed();
case hipErrorUnmapFailed:
throw unmap_failed();
case hipErrorArrayIsMapped:
throw array_is_mapped();
case hipErrorAlreadyMapped:
throw already_mapped();
case hipErrorNoBinaryForGpu:
throw no_binary_for_gpu();
case hipErrorAlreadyAcquired:
throw already_acquired();
case hipErrorNotMapped:
throw not_mapped();
case hipErrorNotMappedAsArray:
throw not_mapped_as_array();
case hipErrorNotMappedAsPointer:
throw not_mapped_as_pointer();
case hipErrorECCNotCorrectable:
throw ecc_uncorrectable();
case hipErrorUnsupportedLimit:
throw unsupported_limit();
case hipErrorContextAlreadyInUse:
throw context_already_in_use();
case hipErrorPeerAccessUnsupported:
throw peer_access_unsupported();
case hipErrorInvalidKernelFile:
throw invalid_ptx();
case hipErrorInvalidGraphicsContext:
throw invalid_graphics_context();
case hipErrorInvalidSource:
throw invalid_source();
case hipErrorFileNotFound:
throw file_not_found();
case hipErrorSharedObjectSymbolNotFound:
throw shared_object_symbol_not_found();
case hipErrorSharedObjectInitFailed:
throw shared_object_init_failed();
case hipErrorOperatingSystem:
throw operating_system();
case hipErrorInvalidResourceHandle:
throw invalid_handle();
case hipErrorNotFound:
throw not_found();
case hipErrorNotReady:
throw not_ready();
case hipErrorIllegalAddress:
throw illegal_address();
case hipErrorLaunchOutOfResources:
throw launch_out_of_resources();
case hipErrorLaunchTimeOut:
throw launch_timeout();
// case hipErrorLaunchIncompatibleTexturing : throw
// launch_incompatible_texturing();
case hipErrorPeerAccessAlreadyEnabled:
throw peer_access_already_enabled();
case hipErrorPeerAccessNotEnabled:
throw peer_access_not_enabled();
// case hipErrorPrimaryContextActive : throw primary_context_active();
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
case hipErrorAssert:
throw assert_error();
// case hipErrorTooManyPeers : throw too_many_peers();
case hipErrorHostMemoryAlreadyRegistered:
throw host_memory_already_registered();
case hipErrorHostMemoryNotRegistered:
throw host_memory_not_registered();
// case hipErrorHardwareStackError : throw hardware_stack_error();
// case hipErrorIllegalInstruction : throw illegal_instruction();
// case hipErrorMisalignedAddress : throw misaligned_address();
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
// case hipErrorInvalidPc : throw invalid_pc();
case hipErrorLaunchFailure:
throw launch_failed();
// case hipErrorNotPermitted : throw not_permitted();
case hipErrorNotSupported:
throw not_supported();
case hipErrorUnknown:
throw unknown();
default:
throw unknown();
}
}
} // namespace driver
} // namespace triton

View File

@@ -1,73 +1,73 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <fstream>
#if __has_include(<unistd.h>)
#include <unistd.h>
#include <unistd.h>
#endif
#include <memory>
#include <regex>
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/exec.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "triton/tools/sys/exec.hpp"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <memory>
#include <regex>
// begin AMD stuff
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
// end AMD stuff
extern "C"{
int set_curterm(char* nterm){ return 0; }
int del_curterm(char* nterm){ return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace triton{
namespace driver{
namespace triton {
namespace driver {
void init_llvm() {
LLVMInitializeNVPTXTargetInfo();
@@ -80,82 +80,93 @@ void init_llvm() {
LLVMInitializeAMDGPUAsmPrinter();
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
static bool find_and_replace(std::string &str, const std::string &begin,
const std::string &end,
const std::string &target) {
size_t start_replace = str.find(begin);
size_t end_replace = str.find(end, start_replace);
if(start_replace == std::string::npos)
if (start_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
std::string path_to_ptxas(int& version) {
std::string path_to_ptxas(int &version) {
std::vector<std::string> rets;
std::string ret;
// search pathes for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
if (!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
for (std::string prefix : ptxas_prefixes) {
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if(works) {
if (works) {
working_ptxas.push_back(ptxas);
rets.push_back(ret);
}
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
if (working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, "
"/usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
bool found = false;
// currently choosing the first ptxas. Other logics can be implemented in future
for(std::string ret : rets) {
if(std::regex_search(ret, match, version_regex)){
// currently choosing the first ptxas. Other logics can be implemented in
// future
for (std::string ret : rets) {
if (std::regex_search(ret, match, version_regex)) {
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major*1000 + minor*10;
version = major * 1000 + minor * 10;
found = true;
break;
}
}
if ( not found) {
if (not found) {
throw std::runtime_error("Error in parsing version");
}
return ptxas;
}
int vptx(int version){
if(version >= 11040) return 74;
if(version >= 11030) return 73;
if(version >= 11020) return 72;
if(version >= 11010) return 71;
if(version >= 11000) return 70;
if(version >= 10020) return 65;
if(version >= 10010) return 64;
if(version >= 10000) return 63;
int vptx(int version) {
if (version >= 11040)
return 74;
if (version >= 11030)
return 73;
if (version >= 11020)
return 72;
if (version >= 11010)
return 71;
if (version >= 11000)
return 70;
if (version >= 10020)
return 65;
if (version >= 10010)
return 64;
if (version >= 10000)
return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string llir_to_ptx(llvm::Module *module, int cc, int version) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
auto *short_ptr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
@@ -170,7 +181,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
@@ -181,16 +193,18 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
@@ -200,19 +214,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas,
int cc) {
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
@@ -221,15 +241,16 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str();
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc +
" -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if(err != 0){
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
@@ -237,7 +258,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
@@ -251,11 +272,11 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) {
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
@@ -270,17 +291,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
@@ -295,33 +317,37 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::string isabin_path =
std::string("/tmp/") + module_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
{
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
if (ec) {
std::cout << isabin_path << " was not created. error code: " << ec
<< std::endl;
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string amdgcn_path =
std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
// generate HASCO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string hsaco_path =
std::string("/tmp/") + module_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
"-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result)
{
if (lld_result) {
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
@@ -330,33 +356,29 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
return hsaco_path;
}
hipModule_t amdgpu_to_hipmodule(const std::string& path) {
hipModule_t amdgpu_to_hipmodule(const std::string &path) {
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
hipJitOptionLogVerbose};
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize,
(void*)_err, (void*)(uintptr_t)logbufsize,
(void*)_log, (void*)1};
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret;
}
}
}
} // namespace driver
} // namespace triton

View File

@@ -18,60 +18,83 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// @{
/// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
struct is_method {
handle class_;
is_method(const handle &c) : class_(c) {}
};
/// Annotation for operators
struct is_operator { };
struct is_operator {};
/// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } };
struct scope {
handle value;
scope(const handle &s) : value(s) {}
};
/// Annotation for documentation
struct doc { const char *value; doc(const char *value) : value(value) { } };
struct doc {
const char *value;
doc(const char *value) : value(value) {}
};
/// Annotation for function names
struct name { const char *value; name(const char *value) : value(value) { } };
struct name {
const char *value;
name(const char *value) : value(value) {}
};
/// Annotation indicating that a function is an overload associated with a given "sibling"
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
/// Annotation indicating that a function is an overload associated with a given
/// "sibling"
struct sibling {
handle value;
sibling(const handle &value) : value(value.ptr()) {}
};
/// Annotation indicating that a class derives from another given type
template <typename T> struct base {
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
base() { }
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as "
"a template argument to class_")
base() {}
};
/// Keep patient alive while nurse lives
template <size_t Nurse, size_t Patient> struct keep_alive { };
template <size_t Nurse, size_t Patient> struct keep_alive {};
/// Annotation indicating that a class is involved in a multiple inheritance relationship
struct multiple_inheritance { };
/// Annotation indicating that a class is involved in a multiple inheritance
/// relationship
struct multiple_inheritance {};
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
struct dynamic_attr { };
struct dynamic_attr {};
/// Annotation which enables the buffer protocol for a type
struct buffer_protocol { };
struct buffer_protocol {};
/// Annotation which requests that a special metaclass is created for a type
struct metaclass {
handle value;
handle value;
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
metaclass() {}
PYBIND11_DEPRECATED(
"py::metaclass() is no longer required. It's turned on by default now.")
metaclass() {}
/// Override pybind11's default metaclass
explicit metaclass(handle value) : value(value) { }
/// Override pybind11's default metaclass
explicit metaclass(handle value) : value(value) {}
};
/// Annotation that marks a class as local to the module:
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
struct module_local {
const bool value;
constexpr module_local(bool v = true) : value(v) {}
};
/// Annotation to mark enums as an arithmetic type
struct arithmetic { };
struct arithmetic {};
/** \rst
A call policy which places one or more guard variables (``Ts...``) around the function call.
A call policy which places one or more guard variables (``Ts...``) around
the function call.
For example, this definition:
@@ -92,20 +115,19 @@ template <typename... Ts> struct call_guard;
template <> struct call_guard<> { using type = detail::void_type; };
template <typename T>
struct call_guard<T> {
static_assert(std::is_default_constructible<T>::value,
"The guard type must be default constructible");
template <typename T> struct call_guard<T> {
static_assert(std::is_default_constructible<T>::value,
"The guard type must be default constructible");
using type = T;
using type = T;
};
template <typename T, typename... Ts>
struct call_guard<T, Ts...> {
struct type {
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
typename call_guard<Ts...>::type next{};
};
template <typename T, typename... Ts> struct call_guard<T, Ts...> {
struct type {
T guard{}; // Compose multiple guard types with left-to-right
// default-constructor order
typename call_guard<Ts...>::type next{};
};
};
/// @} annotations
@@ -115,181 +137,190 @@ NAMESPACE_BEGIN(detail)
enum op_id : int;
enum op_type : int;
struct undefined_t;
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
template <op_id id, op_type ot, typename L = undefined_t,
typename R = undefined_t>
struct op_;
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call,
handle ret);
/// Internal data structure which holds metadata about a keyword argument
struct argument_record {
const char *name; ///< Argument name
const char *descr; ///< Human-readable version of the argument value
handle value; ///< Associated Python object
bool convert : 1; ///< True if the argument is allowed to convert when loading
bool none : 1; ///< True if None is allowed when loading
const char *name; ///< Argument name
const char *descr; ///< Human-readable version of the argument value
handle value; ///< Associated Python object
bool convert : 1; ///< True if the argument is allowed to convert when loading
bool none : 1; ///< True if None is allowed when loading
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
: name(name), descr(descr), value(value), convert(convert), none(none) { }
argument_record(const char *name, const char *descr, handle value,
bool convert, bool none)
: name(name), descr(descr), value(value), convert(convert), none(none) {}
};
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
/// Internal data structure which holds metadata about a bound function
/// (signature, overloads, etc.)
struct function_record {
function_record()
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
function_record()
: is_constructor(false), is_new_style_constructor(false),
is_stateless(false), is_operator(false), has_args(false),
has_kwargs(false), is_method(false) {}
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
// User-specified documentation string
char *doc = nullptr;
// User-specified documentation string
char *doc = nullptr;
/// Human-readable version of the function signature
char *signature = nullptr;
/// Human-readable version of the function signature
char *signature = nullptr;
/// List of registered keyword arguments
std::vector<argument_record> args;
/// List of registered keyword arguments
std::vector<argument_record> args;
/// Pointer to lambda function which converts arguments and performs the actual call
handle (*impl) (function_call &) = nullptr;
/// Pointer to lambda function which converts arguments and performs the
/// actual call
handle (*impl)(function_call &) = nullptr;
/// Storage for the wrapped function pointer and captured data, if any
void *data[3] = { };
/// Storage for the wrapped function pointer and captured data, if any
void *data[3] = {};
/// Pointer to custom destructor for 'data' (if needed)
void (*free_data) (function_record *ptr) = nullptr;
/// Pointer to custom destructor for 'data' (if needed)
void (*free_data)(function_record *ptr) = nullptr;
/// Return value policy associated with this function
return_value_policy policy = return_value_policy::automatic;
/// Return value policy associated with this function
return_value_policy policy = return_value_policy::automatic;
/// True if name == '__init__'
bool is_constructor : 1;
/// True if name == '__init__'
bool is_constructor : 1;
/// True if this is a new-style `__init__` defined in `detail/init.h`
bool is_new_style_constructor : 1;
/// True if this is a new-style `__init__` defined in `detail/init.h`
bool is_new_style_constructor : 1;
/// True if this is a stateless function pointer
bool is_stateless : 1;
/// True if this is a stateless function pointer
bool is_stateless : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if the function has a '*args' argument
bool has_args : 1;
/// True if the function has a '*args' argument
bool has_args : 1;
/// True if the function has a '**kwargs' argument
bool has_kwargs : 1;
/// True if the function has a '**kwargs' argument
bool has_kwargs : 1;
/// True if this is a method
bool is_method : 1;
/// True if this is a method
bool is_method : 1;
/// Number of arguments (including py::args and/or py::kwargs, if present)
std::uint16_t nargs;
/// Number of arguments (including py::args and/or py::kwargs, if present)
std::uint16_t nargs;
/// Python method object
PyMethodDef *def = nullptr;
/// Python method object
PyMethodDef *def = nullptr;
/// Python handle to the parent scope (a class or a module)
handle scope;
/// Python handle to the parent scope (a class or a module)
handle scope;
/// Python handle to the sibling function representing an overload chain
handle sibling;
/// Python handle to the sibling function representing an overload chain
handle sibling;
/// Pointer to next overload
function_record *next = nullptr;
/// Pointer to next overload
function_record *next = nullptr;
};
/// Special data structure which (temporarily) holds metadata about a bound class
/// Special data structure which (temporarily) holds metadata about a bound
/// class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
default_holder(true), module_local(false) { }
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false),
buffer_protocol(false), default_holder(true), module_local(false) {}
/// Handle to the parent scope
handle scope;
/// Handle to the parent scope
handle scope;
/// Name of the class
const char *name = nullptr;
/// Name of the class
const char *name = nullptr;
// Pointer to RTTI type_info data structure
const std::type_info *type = nullptr;
// Pointer to RTTI type_info data structure
const std::type_info *type = nullptr;
/// How large is the underlying C++ type?
size_t type_size = 0;
/// How large is the underlying C++ type?
size_t type_size = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
/// Function pointer to class_<..>::dealloc
void (*dealloc)(detail::value_and_holder &) = nullptr;
/// Function pointer to class_<..>::dealloc
void (*dealloc)(detail::value_and_holder &) = nullptr;
/// List of base classes of the newly created type
list bases;
/// List of base classes of the newly created type
list bases;
/// Optional docstring
const char *doc = nullptr;
/// Optional docstring
const char *doc = nullptr;
/// Custom metaclass (optional)
handle metaclass;
/// Custom metaclass (optional)
handle metaclass;
/// Multiple inheritance marker
bool multiple_inheritance : 1;
/// Multiple inheritance marker
bool multiple_inheritance : 1;
/// Does the class manage a __dict__?
bool dynamic_attr : 1;
/// Does the class manage a __dict__?
bool dynamic_attr : 1;
/// Does the class implement the buffer protocol?
bool buffer_protocol : 1;
/// Does the class implement the buffer protocol?
bool buffer_protocol : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
/// Is the class definition local to the module shared object?
bool module_local : 1;
/// Is the class definition local to the module shared object?
bool module_local : 1;
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) +
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname + "\" " +
(base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)
dynamic_attr = true;
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
PYBIND11_NOINLINE void add_base(const std::type_info &base,
void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) +
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname +
"\" " + (base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *)base_info->type);
if (base_info->type->tp_dictoffset != 0)
dynamic_attr = true;
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
}
};
inline function_call::function_call(const function_record &f, handle p) :
func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);
inline function_call::function_call(const function_record &f, handle p)
: func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);
}
/// Tag for a new-style `__init__` defined in `detail/init.h`
struct is_new_style_constructor { };
struct is_new_style_constructor {};
/**
* Partial template specializations to process custom attributes provided to
@@ -300,135 +331,191 @@ struct is_new_style_constructor { };
template <typename T, typename SFINAE = void> struct process_attribute;
template <typename T> struct process_attribute_default {
/// Default implementation: do nothing
static void init(const T &, function_record *) { }
static void init(const T &, type_record *) { }
static void precall(function_call &) { }
static void postcall(function_call &, handle) { }
/// Default implementation: do nothing
static void init(const T &, function_record *) {}
static void init(const T &, type_record *) {}
static void precall(function_call &) {}
static void postcall(function_call &, handle) {}
};
/// Process an attribute specifying the function's name
template <> struct process_attribute<name> : process_attribute_default<name> {
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
static void init(const name &n, function_record *r) {
r->name = const_cast<char *>(n.value);
}
};
/// Process an attribute specifying the function's docstring
template <> struct process_attribute<doc> : process_attribute_default<doc> {
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
static void init(const doc &n, function_record *r) {
r->doc = const_cast<char *>(n.value);
}
};
/// Process an attribute specifying the function's docstring (provided as a C-style string)
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
/// Process an attribute specifying the function's docstring (provided as a
/// C-style string)
template <>
struct process_attribute<const char *>
: process_attribute_default<const char *> {
static void init(const char *d, function_record *r) {
r->doc = const_cast<char *>(d);
}
static void init(const char *d, type_record *r) {
r->doc = const_cast<char *>(d);
}
};
template <> struct process_attribute<char *> : process_attribute<const char *> { };
template <>
struct process_attribute<char *> : process_attribute<const char *> {};
/// Process an attribute indicating the function's return value policy
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
template <>
struct process_attribute<return_value_policy>
: process_attribute_default<return_value_policy> {
static void init(const return_value_policy &p, function_record *r) {
r->policy = p;
}
};
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
/// Process an attribute which indicates that this is an overloaded function
/// associated with a given sibling
template <>
struct process_attribute<sibling> : process_attribute_default<sibling> {
static void init(const sibling &s, function_record *r) {
r->sibling = s.value;
}
};
/// Process an attribute which indicates that this function is a method
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
template <>
struct process_attribute<is_method> : process_attribute_default<is_method> {
static void init(const is_method &s, function_record *r) {
r->is_method = true;
r->scope = s.class_;
}
};
/// Process an attribute which indicates the parent scope of a method
template <> struct process_attribute<scope> : process_attribute_default<scope> {
static void init(const scope &s, function_record *r) { r->scope = s.value; }
static void init(const scope &s, function_record *r) { r->scope = s.value; }
};
/// Process an attribute which indicates that this function is an operator
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
template <>
struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) {
r->is_operator = true;
}
};
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
template <>
struct process_attribute<is_new_style_constructor>
: process_attribute_default<is_new_style_constructor> {
static void init(const is_new_style_constructor &, function_record *r) {
r->is_new_style_constructor = true;
}
};
/// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> {
static void init(const arg &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
}
static void init(const arg &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr, handle(), true /*convert*/,
false /*none not allowed*/);
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert,
a.flag_none);
}
};
/// Process a keyword argument attribute (*with* a default value)
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
static void init(const arg_v &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
static void init(const arg_v &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/,
true /*convert*/, false /*none not allowed*/);
if (!a.value) {
if (!a.value) {
#if !defined(NDEBUG)
std::string descr("'");
if (a.name) descr += std::string(a.name) + ": ";
descr += a.type + "'";
if (r->is_method) {
if (r->name)
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
else
descr += " in method of '" + (std::string) str(r->scope) + "'";
} else if (r->name) {
descr += " in function '" + (std::string) r->name + "'";
}
pybind11_fail("arg(): could not convert default argument "
+ descr + " into a Python object (type not registered yet?)");
std::string descr("'");
if (a.name)
descr += std::string(a.name) + ": ";
descr += a.type + "'";
if (r->is_method) {
if (r->name)
descr += " in method '" + (std::string)str(r->scope) + "." +
(std::string)r->name + "'";
else
descr += " in method of '" + (std::string)str(r->scope) + "'";
} else if (r->name) {
descr += " in function '" + (std::string)r->name + "'";
}
pybind11_fail("arg(): could not convert default argument " + descr +
" into a Python object (type not registered yet?)");
#else
pybind11_fail("arg(): could not convert default argument "
"into a Python object (type not registered yet?). "
"Compile in debug mode for more information.");
pybind11_fail("arg(): could not convert default argument "
"into a Python object (type not registered yet?). "
"Compile in debug mode for more information.");
#endif
}
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
}
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert,
a.flag_none);
}
};
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
/// Process a parent class attribute. Single inheritance only (class_ itself
/// already guarantees that)
template <typename T>
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>>
: process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
};
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
/// Process a parent class attribute (deprecated, does not support multiple
/// inheritance)
template <typename T>
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
static void init(const base<T> &, type_record *r) {
r->add_base(typeid(T), nullptr);
}
};
/// Process a multiple inheritance attribute
template <>
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
struct process_attribute<multiple_inheritance>
: process_attribute_default<multiple_inheritance> {
static void init(const multiple_inheritance &, type_record *r) {
r->multiple_inheritance = true;
}
};
template <>
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
struct process_attribute<dynamic_attr>
: process_attribute_default<dynamic_attr> {
static void init(const dynamic_attr &, type_record *r) {
r->dynamic_attr = true;
}
};
template <>
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
struct process_attribute<buffer_protocol>
: process_attribute_default<buffer_protocol> {
static void init(const buffer_protocol &, type_record *r) {
r->buffer_protocol = true;
}
};
template <>
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
static void init(const metaclass &m, type_record *r) {
r->metaclass = m.value;
}
};
template <>
struct process_attribute<module_local> : process_attribute_default<module_local> {
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
struct process_attribute<module_local>
: process_attribute_default<module_local> {
static void init(const module_local &l, type_record *r) {
r->module_local = l.value;
}
};
/// Process an 'arithmetic' attribute for enums (does nothing here)
@@ -436,57 +523,78 @@ template <>
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
template <typename... Ts>
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
struct process_attribute<call_guard<Ts...>>
: process_attribute_default<call_guard<Ts...>> {};
/**
* Process a keep_alive call policy -- invokes keep_alive_impl during the
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
* otherwise
*/
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void postcall(function_call &, handle) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void precall(function_call &) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
template <size_t Nurse, size_t Patient>
struct process_attribute<keep_alive<Nurse, Patient>>
: public process_attribute_default<keep_alive<Nurse, Patient>> {
template <size_t N = Nurse, size_t P = Patient,
enable_if_t<N != 0 && P != 0, int> = 0>
static void precall(function_call &call) {
keep_alive_impl(Nurse, Patient, call, handle());
}
template <size_t N = Nurse, size_t P = Patient,
enable_if_t<N != 0 && P != 0, int> = 0>
static void postcall(function_call &, handle) {}
template <size_t N = Nurse, size_t P = Patient,
enable_if_t<N == 0 || P == 0, int> = 0>
static void precall(function_call &) {}
template <size_t N = Nurse, size_t P = Patient,
enable_if_t<N == 0 || P == 0, int> = 0>
static void postcall(function_call &call, handle ret) {
keep_alive_impl(Nurse, Patient, call, ret);
}
};
/// Recursively iterate over variadic template arguments
template <typename... Args> struct process_attributes {
static void init(const Args&... args, function_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void init(const Args&... args, type_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void precall(function_call &call) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
ignore_unused(unused);
}
static void postcall(function_call &call, handle fn_ret) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
ignore_unused(unused);
}
static void init(const Args &... args, function_record *r) {
int unused[] = {
0, (process_attribute<typename std::decay<Args>::type>::init(args, r),
0)...};
ignore_unused(unused);
}
static void init(const Args &... args, type_record *r) {
int unused[] = {
0, (process_attribute<typename std::decay<Args>::type>::init(args, r),
0)...};
ignore_unused(unused);
}
static void precall(function_call &call) {
int unused[] = {
0, (process_attribute<typename std::decay<Args>::type>::precall(call),
0)...};
ignore_unused(unused);
}
static void postcall(function_call &call, handle fn_ret) {
int unused[] = {
0, (process_attribute<typename std::decay<Args>::type>::postcall(
call, fn_ret),
0)...};
ignore_unused(unused);
}
};
template <typename T>
using is_call_guard = is_instantiation<call_guard, T>;
template <typename T> using is_call_guard = is_instantiation<call_guard, T>;
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or
/// `void_type` if none found)
template <typename... Extra>
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
using extract_guard_t =
typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
/// Check the number of named arguments at compile time
template <typename... Extra,
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
}
NAMESPACE_END(detail)

View File

@@ -15,93 +15,112 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// Information record describing a Python buffer object
struct buffer_info {
void *ptr = nullptr; // Pointer to the underlying storage
ssize_t itemsize = 0; // Size of individual items in bytes
ssize_t size = 0; // Total number of entries
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
ssize_t ndim = 0; // Number of dimensions
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
void *ptr = nullptr; // Pointer to the underlying storage
ssize_t itemsize = 0; // Size of individual items in bytes
ssize_t size = 0; // Total number of entries
std::string format; // For homogeneous buffers, this should be set to
// format_descriptor<T>::format()
ssize_t ndim = 0; // Number of dimensions
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
std::vector<ssize_t> strides; // Number of entries between adjacent entries
// (for each per dimension)
buffer_info() { }
buffer_info() {}
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < (size_t) ndim; ++i)
size *= shape[i];
}
template <typename T>
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
template <typename T>
buffer_info(T *ptr, ssize_t size)
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
this->view = view;
this->ownview = ownview;
}
buffer_info(const buffer_info &) = delete;
buffer_info& operator=(const buffer_info &) = delete;
buffer_info(buffer_info &&other) {
(*this) = std::move(other);
}
buffer_info& operator=(buffer_info &&rhs) {
ptr = rhs.ptr;
itemsize = rhs.itemsize;
size = rhs.size;
format = std::move(rhs.format);
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(ownview, rhs.ownview);
return *this;
}
~buffer_info() {
if (view && ownview) { PyBuffer_Release(view); delete view; }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format,
ssize_t ndim, detail::any_container<ssize_t> shape_in,
detail::any_container<ssize_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != (ssize_t)shape.size() || ndim != (ssize_t)strides.size())
pybind11_fail(
"buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < (size_t)ndim; ++i)
size *= shape[i];
}
template <typename T>
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in,
detail::any_container<ssize_t> strides_in)
: buffer_info(private_ctr_tag(), ptr, sizeof(T),
format_descriptor<T>::format(),
static_cast<ssize_t>(shape_in->size()), std::move(shape_in),
std::move(strides_in)) {}
buffer_info(void *ptr, ssize_t itemsize, const std::string &format,
ssize_t size)
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) {}
template <typename T>
buffer_info(T *ptr, ssize_t size)
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) {}
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim},
{view->strides, view->strides + view->ndim}) {
this->view = view;
this->ownview = ownview;
}
buffer_info(const buffer_info &) = delete;
buffer_info &operator=(const buffer_info &) = delete;
buffer_info(buffer_info &&other) { (*this) = std::move(other); }
buffer_info &operator=(buffer_info &&rhs) {
ptr = rhs.ptr;
itemsize = rhs.itemsize;
size = rhs.size;
format = std::move(rhs.format);
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(ownview, rhs.ownview);
return *this;
}
~buffer_info() {
if (view && ownview) {
PyBuffer_Release(view);
delete view;
}
}
private:
struct private_ctr_tag { };
struct private_ctr_tag {};
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize,
const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in,
detail::any_container<ssize_t> &&strides_in)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in),
std::move(strides_in)) {}
Py_buffer *view = nullptr;
bool ownview = false;
Py_buffer *view = nullptr;
bool ownview = false;
};
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
}
static bool compare(const buffer_info &b) {
return b.format == format_descriptor<T>::format() &&
b.itemsize == (ssize_t)sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info &b) {
return (size_t)b.itemsize == sizeof(T) &&
(b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) &&
b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) &&
b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
/*
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
pybind11/chrono.h: Transparent conversion between std::chrono and python's
datetime
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
@@ -11,20 +12,21 @@
#pragma once
#include "pybind11.h"
#include <chrono>
#include <cmath>
#include <ctime>
#include <chrono>
#include <datetime.h>
// Backport the PyDateTime_DELTA functions from Python3.3 if required
#ifndef PyDateTime_DELTA_GET_DAYS
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta *)o)->days)
#endif
#ifndef PyDateTime_DELTA_GET_SECONDS
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta *)o)->seconds)
#endif
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
#define PyDateTime_DELTA_GET_MICROSECONDS(o) \
(((PyDateTime_Delta *)o)->microseconds)
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
@@ -32,131 +34,154 @@ NAMESPACE_BEGIN(detail)
template <typename type> class duration_caster {
public:
typedef typename type::rep rep;
typedef typename type::period period;
typedef typename type::rep rep;
typedef typename type::period period;
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
bool load(handle src, bool) {
using namespace std::chrono;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
// If invoked with datetime.delta object
if (PyDelta_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
return true;
}
// If invoked with a float we assume it is seconds and convert
else if (PyFloat_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
return true;
}
else return false;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) {
PyDateTime_IMPORT;
}
// If this is a duration just return it back
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
return src;
if (!src)
return false;
// If invoked with datetime.delta object
if (PyDelta_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
days(PyDateTime_DELTA_GET_DAYS(src.ptr())) +
seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) +
microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
return true;
}
// If invoked with a float we assume it is seconds and convert
else if (PyFloat_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
duration<double>(PyFloat_AsDouble(src.ptr()))));
return true;
} else
return false;
}
// If this is a duration just return it back
static const std::chrono::duration<rep, period> &
get_duration(const std::chrono::duration<rep, period> &src) {
return src;
}
// If this is a time_point get the time_since_epoch
template <typename Clock>
static std::chrono::duration<rep, period> get_duration(
const std::chrono::time_point<Clock, std::chrono::duration<rep, period>>
&src) {
return src.time_since_epoch();
}
static handle cast(const type &src, return_value_policy /* policy */,
handle /* parent */) {
using namespace std::chrono;
// Use overloaded function to get our duration from our source
// Works out if it is a duration or time_point and get the duration
auto d = get_duration(src);
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) {
PyDateTime_IMPORT;
}
// If this is a time_point get the time_since_epoch
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
return src.time_since_epoch();
}
// Declare these special duration types so the conversions happen with the
// correct primitive types (int)
using dd_t = duration<int, std::ratio<86400>>;
using ss_t = duration<int, std::ratio<1>>;
using us_t = duration<int, std::micro>;
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
auto dd = duration_cast<dd_t>(d);
auto subd = d - dd;
auto ss = duration_cast<ss_t>(subd);
auto us = duration_cast<us_t>(subd - ss);
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
}
// Use overloaded function to get our duration from our source
// Works out if it is a duration or time_point and get the duration
auto d = get_duration(src);
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using dd_t = duration<int, std::ratio<86400>>;
using ss_t = duration<int, std::ratio<1>>;
using us_t = duration<int, std::micro>;
auto dd = duration_cast<dd_t>(d);
auto subd = d - dd;
auto ss = duration_cast<ss_t>(subd);
auto us = duration_cast<us_t>(subd - ss);
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
};
// This is for casting times on the system clock into datetime.datetime instances
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
// This is for casting times on the system clock into datetime.datetime
// instances
template <typename Duration>
class type_caster<
std::chrono::time_point<std::chrono::system_clock, Duration>> {
public:
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
bool load(handle src, bool) {
using namespace std::chrono;
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
if (PyDateTime_Check(src.ptr())) {
std::tm cal;
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
cal.tm_isdst = -1;
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
return true;
}
else return false;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) {
PyDateTime_IMPORT;
}
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
if (!src)
return false;
if (PyDateTime_Check(src.ptr())) {
std::tm cal;
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
cal.tm_isdst = -1;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
value = system_clock::from_time_t(std::mktime(&cal)) +
microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
return true;
} else
return false;
}
std::time_t tt = system_clock::to_time_t(src);
// this function uses static memory so it's best to copy it out asap just in case
// otherwise other code that is using localtime may break this (not just python code)
std::tm localtime = *std::localtime(&tt);
static handle
cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src,
return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using us_t = duration<int, std::micro>;
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
localtime.tm_mon + 1,
localtime.tm_mday,
localtime.tm_hour,
localtime.tm_min,
localtime.tm_sec,
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) {
PyDateTime_IMPORT;
}
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
std::time_t tt = system_clock::to_time_t(src);
// this function uses static memory so it's best to copy it out asap just in
// case otherwise other code that is using localtime may break this (not
// just python code)
std::tm localtime = *std::localtime(&tt);
// Declare these special duration types so the conversions happen with the
// correct primitive types (int)
using us_t = duration<int, std::micro>;
return PyDateTime_FromDateAndTime(
localtime.tm_year + 1900, localtime.tm_mon + 1, localtime.tm_mday,
localtime.tm_hour, localtime.tm_min, localtime.tm_sec,
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
};
// Other clocks that are not the system clock are not measured as datetime.datetime objects
// since they are not measured on calendar time. So instead we just make them timedeltas
// Or if they have passed us a time as a float we convert that
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
};
// Other clocks that are not the system clock are not measured as
// datetime.datetime objects since they are not measured on calendar time. So
// instead we just make them timedeltas Or if they have passed us a time as a
// float we convert that
template <typename Clock, typename Duration>
class type_caster<std::chrono::time_point<Clock, Duration>>
: public duration_caster<std::chrono::time_point<Clock, Duration>> {};
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
: public duration_caster<std::chrono::duration<Rep, Period>> {
};
template <typename Rep, typename Period>
class type_caster<std::chrono::duration<Rep, Period>>
: public duration_caster<std::chrono::duration<Rep, Period>> {};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,2 +1,3 @@
#include "detail/common.h"
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
#warning \
"Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."

View File

@@ -14,52 +14,59 @@
/// glibc defines I as a macro which breaks things, e.g., boost template names
#ifdef I
# undef I
#undef I
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = { 'Z', c, '\0' };
static std::string format() { return std::string(value); }
template <typename T>
struct format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = {'Z', c, '\0'};
static std::string format() { return std::string(value); }
};
#ifndef PYBIND11_CPP17
template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
template <typename T>
constexpr const char format_descriptor<
std::complex<T>,
detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
NAMESPACE_BEGIN(detail)
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
template <typename T>
struct is_fmt_numeric<std::complex<T>,
detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
template <typename T> class type_caster<std::complex<T>> {
public:
bool load(handle src, bool convert) {
if (!src)
return false;
if (!convert && !PyComplex_Check(src.ptr()))
return false;
Py_complex result = PyComplex_AsCComplex(src.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
value = std::complex<T>((T) result.real, (T) result.imag);
return true;
bool load(handle src, bool convert) {
if (!src)
return false;
if (!convert && !PyComplex_Check(src.ptr()))
return false;
Py_complex result = PyComplex_AsCComplex(src.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
value = std::complex<T>((T)result.real, (T)result.imag);
return true;
}
static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
}
static handle cast(const std::complex<T> &src,
return_value_policy /* policy */, handle /* parent */) {
return PyComplex_FromDoubles((double)src.real(), (double)src.imag());
}
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
/*
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
pybind11/detail/descr.h: Helper type for concatenating type signatures at
compile time
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
@@ -15,67 +16,82 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if !defined(_MSC_VER)
# define PYBIND11_DESCR_CONSTEXPR static constexpr
#define PYBIND11_DESCR_CONSTEXPR static constexpr
#else
# define PYBIND11_DESCR_CONSTEXPR const
#define PYBIND11_DESCR_CONSTEXPR const
#endif
/* Concatenate type signatures at compile time */
template <size_t N, typename... Ts>
struct descr {
char text[N + 1];
template <size_t N, typename... Ts> struct descr {
char text[N + 1];
constexpr descr() : text{'\0'} { }
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
constexpr descr() : text{'\0'} {}
constexpr descr(char const (&s)[N + 1])
: descr(s, make_index_sequence<N>()) {}
template <size_t... Is>
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
template <size_t... Is>
constexpr descr(char const (&s)[N + 1], index_sequence<Is...>)
: text{s[Is]..., '\0'} {}
template <typename... Chars>
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
template <typename... Chars>
constexpr descr(char c, Chars... cs)
: text{c, static_cast<char>(cs)..., '\0'} {}
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
return {{&typeid(Ts)..., nullptr}};
}
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1>
types() {
return {{&typeid(Ts)..., nullptr}};
}
};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1,
size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...>
plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
}
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a,
const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
}
template <size_t N>
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
constexpr descr<0> _(char const(&)[1]) { return {}; }
template <size_t N> constexpr descr<N - 1> _(char const (&text)[N]) {
return descr<N - 1>(text);
}
constexpr descr<0> _(char const (&)[1]) { return {}; }
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
template <size_t...Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
template <size_t Rem, size_t... Digits>
struct int_to_str : int_to_str<Rem / 10, Rem % 10, Digits...> {};
template <size_t... Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
};
// Ternary description (like std::conditional)
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
return _(text1);
constexpr enable_if_t<B, descr<N1 - 1>> _(char const (&text1)[N1],
char const (&)[N2]) {
return _(text1);
}
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
return _(text2);
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const (&)[N1],
char const (&text2)[N2]) {
return _(text2);
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) {
return d;
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) {
return d;
}
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
template <size_t Size>
auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
}
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
@@ -83,17 +99,19 @@ template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
constexpr descr<0> concat() { return {}; }
template <size_t N, typename... Ts>
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) {
return descr;
}
template <size_t N, typename... Ts, typename... Args>
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
constexpr auto concat(const descr<N, Ts...> &d, const Args &... args)
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
return d + _(", ") + concat(args...);
return d + _(", ") + concat(args...);
}
template <size_t N, typename... Ts>
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return _("{") + descr + _("}");
return _("{") + descr + _("}");
}
NAMESPACE_END(detail)

View File

@@ -1,5 +1,6 @@
/*
pybind11/detail/init.h: init factory function implementation and support code.
pybind11/detail/init.h: init factory function implementation and support
code.
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
@@ -14,26 +15,26 @@
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <>
class type_caster<value_and_holder> {
template <> class type_caster<value_and_holder> {
public:
bool load(handle h, bool) {
value = reinterpret_cast<value_and_holder *>(h.ptr());
return true;
}
bool load(handle h, bool) {
value = reinterpret_cast<value_and_holder *>(h.ptr());
return true;
}
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
private:
value_and_holder *value = nullptr;
value_and_holder *value = nullptr;
};
NAMESPACE_BEGIN(initimpl)
inline void no_nullptr(void *ptr) {
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
if (!ptr)
throw type_error("pybind11::init(): factory function returned nullptr");
}
// Implementing functions for all forms of py::init<...> and py::init(...)
@@ -41,293 +42,372 @@ template <typename Class> using Cpp = typename Class::type;
template <typename Class> using Alias = typename Class::type_alias;
template <typename Class> using Holder = typename Class::holder_type;
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
template <typename Class>
using is_alias_constructible =
std::is_constructible<Alias<Class>, Cpp<Class> &&>;
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias
// instance.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
bool is_alias(Cpp<Class> *ptr) {
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
}
// Failing fallback version of the above for a no-alias class (always returns false)
template <typename /*Class*/>
constexpr bool is_alias(void *) { return false; }
// Failing fallback version of the above for a no-alias class (always returns
// false)
template <typename /*Class*/> constexpr bool is_alias(void *) { return false; }
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
// back to brace aggregate initiailization so that for aggregate initialization can be used with
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
// Constructs and returns a new object; if the given arguments don't map to a
// constructor, we fall back to brace aggregate initiailization so that for
// aggregate initialization can be used with py::init, e.g. `py::init<int,
// int>` to initialize a `struct T { int a; int b; }`. For non-aggregate types,
// we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
// works, but will not do the expected thing when `T` has an
// `initializer_list<T>` constructor).
template <
typename Class, typename... Args,
detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&... args) {
return new Class(std::forward<Args>(args)...);
}
template <
typename Class, typename... Args,
detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&... args) {
return new Class{std::forward<Args>(args)...};
}
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
// an alias to provide only a single Cpp factory function as long as the Alias can be
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
// inherit all the base class constructors.
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This
// allows types with an alias to provide only a single Cpp factory function as
// long as the Alias can be constructed from an rvalue reference of the base Cpp
// type. This means that Alias classes can, when appropriate, simply define a
// `Alias(Cpp &&)` constructor rather than needing to inherit all the base class
// constructors.
template <typename Class>
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
value_and_holder &v_h, Cpp<Class> &&base) {
v_h.value_ptr() = new Alias<Class>(std::move(base));
v_h.value_ptr() = new Alias<Class>(std::move(base));
}
template <typename Class>
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
value_and_holder &, Cpp<Class> &&) {
throw type_error("pybind11::init(): unable to convert returned instance to required "
"alias class: no `Alias<Class>(Class &&)` constructor available");
[[noreturn]] void
construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
value_and_holder &, Cpp<Class> &&) {
throw type_error(
"pybind11::init(): unable to convert returned instance to required "
"alias class: no `Alias<Class>(Class &&)` constructor available");
}
// Error-generating fallback for factories that don't match one of the below construction
// mechanisms.
template <typename Class>
void construct(...) {
static_assert(!std::is_same<Class, Class>::value /* always false */,
"pybind11::init(): init function must return a compatible pointer, "
"holder, or value");
// Error-generating fallback for factories that don't match one of the below
// construction mechanisms.
template <typename Class> void construct(...) {
static_assert(
!std::is_same<Class, Class>::value /* always false */,
"pybind11::init(): init function must return a compatible pointer, "
"holder, or value");
}
// Pointer return v1: the factory function returns a class pointer for a registered class.
// If we don't need an alias (because this class doesn't have one, or because the final type is
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
// construct an Alias from the returned base instance.
// Pointer return v1: the factory function returns a class pointer for a
// registered class. If we don't need an alias (because this class doesn't have
// one, or because the final type is inherited on the Python side) we can simply
// take over ownership. Otherwise we need to try to construct an Alias from the
// returned base instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
no_nullptr(ptr);
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
// We're going to try to construct an alias by moving the cpp type. Whether or not
// that succeeds, we still need to destroy the original cpp pointer (either the
// moved away leftover, if the alias construction works, or the value itself if we
// throw an error), but we can't just call `delete ptr`: it might have a special
// deleter, or might be shared_from_this. So we construct a holder around it as if
// it was a normal instance, then steal the holder away into a local variable; thus
// the holder and destruction happens when we leave the C++ scope, and the holder
// class gets to handle the destruction however it likes.
v_h.value_ptr() = ptr;
v_h.set_instance_registered(true); // To prevent init_instance from registering it
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
v_h.set_instance_registered(false);
no_nullptr(ptr);
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
// We're going to try to construct an alias by moving the cpp type. Whether
// or not that succeeds, we still need to destroy the original cpp pointer
// (either the moved away leftover, if the alias construction works, or the
// value itself if we throw an error), but we can't just call `delete ptr`:
// it might have a special deleter, or might be shared_from_this. So we
// construct a holder around it as if it was a normal instance, then steal
// the holder away into a local variable; thus the holder and destruction
// happens when we leave the C++ scope, and the holder class gets to handle
// the destruction however it likes.
v_h.value_ptr() = ptr;
v_h.set_instance_registered(
true); // To prevent init_instance from registering it
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
Holder<Class> temp_holder(
std::move(v_h.holder<Holder<Class>>())); // Steal the holder
v_h.type->dealloc(
v_h); // Destroys the moved-out holder remains, resets value ptr to null
v_h.set_instance_registered(false);
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
} else {
// Otherwise the type isn't inherited, so we don't need an Alias
v_h.value_ptr() = ptr;
}
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h,
std::move(*ptr));
} else {
// Otherwise the type isn't inherited, so we don't need an Alias
v_h.value_ptr() = ptr;
}
}
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
// ownership of the pointer.
// Pointer return v2: a factory that always returns an alias instance ptr. We
// simply take over ownership of the pointer.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
no_nullptr(alias_ptr);
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
no_nullptr(alias_ptr);
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
}
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
// derived type (through those holder's implicit conversion from derived class holder constructors).
// Holder return: copy its pointer, and move or copy the returned holder into
// the new instance's holder. This also handles types like std::shared_ptr<T>
// and std::unique_ptr<T> where T is a derived type (through those holder's
// implicit conversion from derived class holder constructors).
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
"is not an alias instance");
auto *ptr = holder_helper<Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias
// instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned "
"holder-wrapped instance "
"is not an alias instance");
v_h.value_ptr() = ptr;
v_h.type->init_instance(v_h.inst, &holder);
v_h.value_ptr() = ptr;
v_h.type->init_instance(v_h.inst, &holder);
}
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
// need it, we simply move-construct the cpp value into a new instance.
// return-by-value version 1: returning a cpp class by value. If the class has
// an alias and an alias is required the alias must have an `Alias(Cpp &&)`
// constructor so that we can construct the alias from the base when needed
// (i.e. because of Python-side inheritance). When we don't need it, we simply
// move-construct the cpp value into a new instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
static_assert(std::is_move_constructible<Cpp<Class>>::value,
"pybind11::init() return-by-value factory function requires a movable class");
if (Class::has_alias && need_alias)
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
else
v_h.value_ptr() = new Cpp<Class>(std::move(result));
static_assert(std::is_move_constructible<Cpp<Class>>::value,
"pybind11::init() return-by-value factory function requires a "
"movable class");
if (Class::has_alias && need_alias)
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h,
std::move(result));
else
v_h.value_ptr() = new Cpp<Class>(std::move(result));
}
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
// Alias instance (even if no the python-side inheritance is involved). The is intended for
// cases where Alias initialization is always desired.
// return-by-value version 2: returning a value of the alias type itself. We
// move-construct an Alias instance (even if no the python-side inheritance is
// involved). The is intended for cases where Alias initialization is always
// desired.
template <typename Class>
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
static_assert(std::is_move_constructible<Alias<Class>>::value,
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
v_h.value_ptr() = new Alias<Class>(std::move(result));
static_assert(std::is_move_constructible<Alias<Class>>::value,
"pybind11::init() return-by-alias-value factory function "
"requires a movable alias class");
v_h.value_ptr() = new Alias<Class>(std::move(result));
}
// Implementing class for py::init<...>()
template <typename... Args>
struct constructor {
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename... Args> struct constructor {
template <typename Class, typename... Extra,
enable_if_t<!Class::has_alias, int> = 0>
static void execute(Class &cl, const Extra &... extra) {
cl.def(
"__init__",
[](value_and_holder &v_h, Args... args) {
v_h.value_ptr() =
construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
},
is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
else
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Cpp<Class>, Args...>::value,
int> = 0>
static void execute(Class &cl, const Extra &... extra) {
cl.def(
"__init__",
[](value_and_holder &v_h, Args... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(
std::forward<Args>(args)...);
else
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
std::forward<Args>(args)...);
},
is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
!std::is_constructible<Cpp<Class>, Args...>::value,
int> = 0>
static void execute(Class &cl, const Extra &... extra) {
cl.def(
"__init__",
[](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
std::forward<Args>(args)...);
},
is_new_style_constructor(), extra...);
}
};
// Implementing class for py::init_alias<...>()
template <typename... Args> struct alias_constructor {
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Alias<Class>, Args...>::value,
int> = 0>
static void execute(Class &cl, const Extra &... extra) {
cl.def(
"__init__",
[](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(
std::forward<Args>(args)...);
},
is_new_style_constructor(), extra...);
}
};
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc = void_type (*)(),
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
typename = function_signature_t<CFunc>,
typename = function_signature_t<AFunc>>
struct factory;
// Specialization for py::init(Func)
template <typename Func, typename Return, typename... Args>
struct factory<Func, void_type (*)(), Return(Args...)> {
remove_reference_t<Func> class_factory;
remove_reference_t<Func> class_factory;
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
factory(Func &&f) : class_factory(std::forward<Func>(f)) {}
// The given class either has no alias or has no separate alias factory;
// this always constructs the class itself. If the class is registered with an alias
// type and an alias instance is needed (i.e. because the final type is a Python class
// inheriting from the C++ type) the returned value needs to either already be an alias
// instance, or the alias needs to be constructible from a `Class &&` argument.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
#if defined(PYBIND11_CPP14)
cl.def("__init__", [func = std::move(class_factory)]
#else
auto &func = class_factory;
cl.def("__init__", [func]
#endif
// The given class either has no alias or has no separate alias factory;
// this always constructs the class itself. If the class is registered with
// an alias type and an alias instance is needed (i.e. because the final type
// is a Python class inheriting from the C++ type) the returned value needs to
// either already be an alias instance, or the alias needs to be constructible
// from a `Class &&` argument.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &... extra) && {
#if defined(PYBIND11_CPP14)
cl.def(
"__init__",
[func = std::move(class_factory)]
#else
auto &func = class_factory;
cl.def(
"__init__",
[func]
#endif
(value_and_holder &v_h, Args... args) {
construct<Class>(v_h, func(std::forward<Args>(args)...),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
construct<Class>(v_h, func(std::forward<Args>(args)...),
Py_TYPE(v_h.inst) != v_h.type->type);
},
is_new_style_constructor(), extra...);
}
};
// Specialization for py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc,
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
template <typename CFunc, typename AFunc, typename CReturn, typename... CArgs,
typename AReturn, typename... AArgs>
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(
sizeof...(CArgs) == sizeof...(AArgs),
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(
all_of<std::is_same<CArgs, AArgs>...>::value,
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
remove_reference_t<CFunc> class_factory;
remove_reference_t<AFunc> alias_factory;
remove_reference_t<CFunc> class_factory;
remove_reference_t<AFunc> alias_factory;
factory(CFunc &&c, AFunc &&a)
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
factory(CFunc &&c, AFunc &&a)
: class_factory(std::forward<CFunc>(c)),
alias_factory(std::forward<AFunc>(a)) {}
// The class factory is called when the `self` type passed to `__init__` is the direct
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra&... extra) && {
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
"only be used if the class has an alias");
#if defined(PYBIND11_CPP14)
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
#else
auto &class_func = class_factory;
auto &alias_func = alias_factory;
cl.def("__init__", [class_func, alias_func]
#endif
// The class factory is called when the `self` type passed to `__init__` is
// the direct class (i.e. not inherited), the alias factory when `self` is a
// Python-side subtype.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &... extra) && {
static_assert(Class::has_alias,
"The two-argument version of `py::init()` can "
"only be used if the class has an alias");
#if defined(PYBIND11_CPP14)
cl.def(
"__init__",
[class_func = std::move(class_factory),
alias_func = std::move(alias_factory)]
#else
auto &class_func = class_factory;
auto &alias_func = alias_factory;
cl.def(
"__init__",
[class_func, alias_func]
#endif
(value_and_holder &v_h, CArgs... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
// If the instance type equals the registered type we don't have inheritance, so
// don't need the alias and can construct using the class function:
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
else
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
}, is_new_style_constructor(), extra...);
}
if (Py_TYPE(v_h.inst) == v_h.type->type)
// If the instance type equals the registered type we don't have
// inheritance, so don't need the alias and can construct using the
// class function:
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...),
false);
else
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...),
true);
},
is_new_style_constructor(), extra...);
}
};
/// Set just the C++ state. Same as `__init__`.
template <typename Class, typename T>
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
construct<Class>(v_h, std::forward<T>(result), need_alias);
construct<Class>(v_h, std::forward<T>(result), need_alias);
}
/// Set both the C++ and Python states
template <typename Class, typename T, typename O,
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *) v_h.inst, "__dict__", result.second);
void setstate(value_and_holder &v_h, std::pair<T, O> &&result,
bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *)v_h.inst, "__dict__", result.second);
}
/// Implementation for py::pickle(GetState, SetState)
template <typename Get, typename Set,
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
template <typename Get, typename Set, typename = function_signature_t<Get>,
typename = function_signature_t<Set>>
struct pickle_factory;
template <typename Get, typename Set,
typename RetState, typename Self, typename NewInstance, typename ArgState>
template <typename Get, typename Set, typename RetState, typename Self,
typename NewInstance, typename ArgState>
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
"The type returned by `__getstate__` must be the same "
"as the argument accepted by `__setstate__`");
static_assert(
std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
"The type returned by `__getstate__` must be the same "
"as the argument accepted by `__setstate__`");
remove_reference_t<Get> get;
remove_reference_t<Set> set;
remove_reference_t<Get> get;
remove_reference_t<Set> set;
pickle_factory(Get get, Set set)
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
pickle_factory(Get get, Set set)
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) {}
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
cl.def("__getstate__", std::move(get));
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &... extra) && {
cl.def("__getstate__", std::move(get));
#if defined(PYBIND11_CPP14)
cl.def("__setstate__", [func = std::move(set)]
cl.def(
"__setstate__",
[func = std::move(set)]
#else
auto &func = set;
cl.def("__setstate__", [func]
auto &func = set;
cl.def(
"__setstate__",
[func]
#endif
(value_and_holder &v_h, ArgState state) {
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
Py_TYPE(v_h.inst) != v_h.type->type);
},
is_new_style_constructor(), extra...);
}
};
NAMESPACE_END(initimpl)

View File

@@ -18,276 +18,323 @@ inline PyTypeObject *make_static_property_type();
inline PyTypeObject *make_default_metaclass();
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
// Thread Specific Storage (TSS) API.
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in
// favor of the new Thread Specific Storage (TSS) API.
#if PY_VERSION_HEX >= 0x03070000
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
#define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
#define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
// Usually an int but a long on Cygwin64 with Python 3.x
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
# if PY_MAJOR_VERSION < 3
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_delete_key_value(key)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
do { \
PyThread_delete_key_value((key)); \
PyThread_set_key_value((key), (value)); \
} while (false)
# else
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_set_key_value((key), nullptr)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
PyThread_set_key_value((key), (value))
# endif
// Usually an int but a long on Cygwin64 with Python 3.x
#define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
#define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
#if PY_MAJOR_VERSION < 3
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_delete_key_value(key)
#define PYBIND11_TLS_REPLACE_VALUE(key, value) \
do { \
PyThread_delete_key_value((key)); \
PyThread_set_key_value((key), (value)); \
} while (false)
#else
#define PYBIND11_TLS_DELETE_VALUE(key) PyThread_set_key_value((key), nullptr)
#define PYBIND11_TLS_REPLACE_VALUE(key, value) \
PyThread_set_key_value((key), (value))
#endif
#endif
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
// which works. If not under a known-good stl, provide our own name-based hash and equality
// functions that use the type name.
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under
// libc++ and possibly other STLs, this means `typeid(A)` from one module won't
// equal `typeid(A)` from another module even when `A` is the same,
// non-hidden-visibility type (e.g. from a common include). Under libstdc++,
// this doesn't happen: equality and the type_index hash are based on the type
// name, which works. If not under a known-good stl, provide our own name-based
// hash and equality functions that use the type name.
#if defined(__GLIBCXX__)
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
return lhs == rhs;
}
using type_hash = std::hash<std::type_index>;
using type_equal_to = std::equal_to<std::type_index>;
#else
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
struct type_hash {
size_t operator()(const std::type_index &t) const {
size_t hash = 5381;
const char *ptr = t.name();
while (auto c = static_cast<unsigned char>(*ptr++))
hash = (hash * 33) ^ c;
return hash;
}
size_t operator()(const std::type_index &t) const {
size_t hash = 5381;
const char *ptr = t.name();
while (auto c = static_cast<unsigned char>(*ptr++))
hash = (hash * 33) ^ c;
return hash;
}
};
struct type_equal_to {
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
bool operator()(const std::type_index &lhs,
const std::type_index &rhs) const {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
};
#endif
template <typename value_type>
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
using type_map =
std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
struct overload_hash {
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
return value;
}
inline size_t
operator()(const std::pair<const PyObject *, const char *> &v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value << 6) +
(value >> 2);
return value;
}
};
/// Internal data structure used to track registered instances and types.
/// Whenever binary incompatible changes are made to this structure,
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
struct internals {
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
PyObject *instance_base;
type_map<type_info *>
registered_types_cpp; // std::type_index -> pybind11's type information
std::unordered_map<PyTypeObject *, std::vector<type_info *>>
registered_types_py; // PyTypeObject* -> base type_info(s)
std::unordered_multimap<const void *, instance *>
registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash>
inactive_overload_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<void (*)(std::exception_ptr)>
registered_exception_translators;
std::unordered_map<std::string, void *>
shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::forward_list<std::string>
static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
PyObject *instance_base;
#if defined(WITH_THREAD)
PYBIND11_TLS_KEY_INIT(tstate);
PyInterpreterState *istate = nullptr;
PYBIND11_TLS_KEY_INIT(tstate);
PyInterpreterState *istate = nullptr;
#endif
};
/// Additional type information which does not fit into the PyTypeObject.
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
struct type_info {
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>>
implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
};
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
/// Tracks the `internals` and `type_info` ABI version independent of the main
/// library version
#define PYBIND11_INTERNALS_VERSION 3
#if defined(_DEBUG)
# define PYBIND11_BUILD_TYPE "_debug"
#define PYBIND11_BUILD_TYPE "_debug"
#else
# define PYBIND11_BUILD_TYPE ""
#define PYBIND11_BUILD_TYPE ""
#endif
#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
#define PYBIND11_INTERNALS_KIND ""
#else
# define PYBIND11_INTERNALS_KIND "_without_thread"
#define PYBIND11_INTERNALS_KIND "_without_thread"
#endif
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_INTERNALS_ID \
"__pybind11_internals_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_MODULE_LOCAL_ID \
"__pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
/// Each module locally stores a pointer to the `internals` data. The data
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
inline internals **&get_internals_pp() {
static internals **internals_pp = nullptr;
return internals_pp;
static internals **internals_pp = nullptr;
return internals_pp;
}
/// Return a reference to the current `internals` data
PYBIND11_NOINLINE inline internals &get_internals() {
auto **&internals_pp = get_internals_pp();
if (internals_pp && *internals_pp)
return **internals_pp;
constexpr auto *id = PYBIND11_INTERNALS_ID;
auto builtins = handle(PyEval_GetBuiltins());
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
internals_pp = static_cast<internals **>(capsule(builtins[id]));
// We loaded builtins through python's builtins, which means that our `error_already_set`
// and `builtin_exception` may be different local classes than the ones set up in the
// initial exception translator, below, so add another for our local exception classes.
//
// libstdc++ doesn't require this (types there are identified only by name)
#if !defined(__GLIBCXX__)
(*internals_pp)->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
}
}
);
#endif
} else {
if (!internals_pp) internals_pp = new internals*();
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
#if PY_VERSION_HEX < 0x03090000
PyEval_InitThreads();
#endif
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
#endif
builtins[id] = capsule(internals_pp);
internals_ptr->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
return;
}
}
);
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
}
auto **&internals_pp = get_internals_pp();
if (internals_pp && *internals_pp)
return **internals_pp;
constexpr auto *id = PYBIND11_INTERNALS_ID;
auto builtins = handle(PyEval_GetBuiltins());
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
internals_pp = static_cast<internals **>(capsule(builtins[id]));
// We loaded builtins through python's builtins, which means that our
// `error_already_set` and `builtin_exception` may be different local
// classes than the ones set up in the initial exception translator, below,
// so add another for our local exception classes.
//
// libstdc++ doesn't require this (types there are identified only by name)
#if !defined(__GLIBCXX__)
(*internals_pp)
->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p)
std::rethrow_exception(p);
} catch (error_already_set &e) {
e.restore();
return;
} catch (const builtin_exception &e) {
e.set_error();
return;
}
});
#endif
} else {
if (!internals_pp)
internals_pp = new internals *();
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
#if PY_VERSION_HEX < 0x03090000
PyEval_InitThreads();
#endif
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
pybind11_fail(
"get_internals: could not successfully initialize the TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail(
"get_internals: could not successfully initialize the TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
#endif
builtins[id] = capsule(internals_pp);
internals_ptr->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p)
std::rethrow_exception(p);
} catch (error_already_set &e) {
e.restore();
return;
} catch (const builtin_exception &e) {
e.set_error();
return;
} catch (const std::bad_alloc &e) {
PyErr_SetString(PyExc_MemoryError, e.what());
return;
} catch (const std::domain_error &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return;
} catch (const std::invalid_argument &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return;
} catch (const std::length_error &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return;
} catch (const std::out_of_range &e) {
PyErr_SetString(PyExc_IndexError, e.what());
return;
} catch (const std::range_error &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return;
} catch (const std::exception &e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
return;
}
});
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base =
make_object_base_type(internals_ptr->default_metaclass);
}
return **internals_pp;
}
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
/// Works like `internals.registered_types_cpp`, but for module-local registered
/// types:
inline type_map<type_info *> &registered_local_types_cpp() {
static type_map<type_info *> locals{};
return locals;
static type_map<type_info *> locals{};
return locals;
}
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args>
const char *c_str(Args &&...args) {
auto &strings = get_internals().static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
/// Constructs a std::string with the given arguments, stores it in `internals`,
/// and returns its `c_str()`. Such strings objects have a long storage
/// duration -- the internal strings are only cleared when the program exits or
/// after interpreter shutdown (when embedding), and so are suitable for c-style
/// strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args> const char *c_str(Args &&... args) {
auto &strings = get_internals().static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
}
NAMESPACE_END(detail)
/// Returns a named pointer that is shared among all extension modules (using the same
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
/// Returns a named pointer that is shared among all extension modules (using
/// the same pybind11 version) running in the current interpreter. Names
/// starting with underscores are reserved for internal usage. Returns `nullptr`
/// if no matching entry was found.
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
}
/// Set the shared data that can be later recovered by `get_shared_data()`.
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name,
void *data) {
detail::get_internals().shared_data[name] = data;
return data;
}
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template<typename T>
T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
/// Returns a typed reference to a shared data entry (by using
/// `get_shared_data()`) if such entry exists. Otherwise, a new object of
/// default-constructible type `T` is added to the shared data under the given
/// name and a reference to it is returned.
template <typename T> T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *)(it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -22,34 +22,35 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring
inline void erase_all(std::string &string, const std::string &search) {
for (size_t pos = 0;;) {
pos = string.find(search, pos);
if (pos == std::string::npos) break;
string.erase(pos, search.length());
}
for (size_t pos = 0;;) {
pos = string.find(search, pos);
if (pos == std::string::npos)
break;
string.erase(pos, search.length());
}
}
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
#if defined(__GNUG__)
int status = 0;
std::unique_ptr<char, void (*)(void *)> res {
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
if (status == 0)
name = res.get();
int status = 0;
std::unique_ptr<char, void (*)(void *)> res{
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free};
if (status == 0)
name = res.get();
#else
detail::erase_all(name, "class ");
detail::erase_all(name, "struct ");
detail::erase_all(name, "enum ");
detail::erase_all(name, "class ");
detail::erase_all(name, "struct ");
detail::erase_all(name, "enum ");
#endif
detail::erase_all(name, "pybind11::");
detail::erase_all(name, "pybind11::");
}
NAMESPACE_END(detail)
/// Return a string representation of a C++ type
template <typename T> static std::string type_id() {
std::string name(typeid(T).name());
detail::clean_type_id(name);
return name;
std::string name(typeid(T).name());
detail::clean_type_id(name);
return name;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

View File

@@ -9,23 +9,23 @@
#pragma once
#include "pybind11.h"
#include "eval.h"
#include "pybind11.h"
#if defined(PYPY_VERSION)
# error Embedding the interpreter is not supported with PyPy
#error Embedding the interpreter is not supported with PyPy
#endif
#if PY_MAJOR_VERSION >= 3
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" PyObject *pybind11_init_impl_##name() { \
return pybind11_init_wrapper_##name(); \
}
#define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" PyObject *pybind11_init_impl_##name() { \
return pybind11_init_wrapper_##name(); \
}
#else
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" void pybind11_init_impl_##name() { \
pybind11_init_wrapper_##name(); \
}
#define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" void pybind11_init_impl_##name() { \
pybind11_init_wrapper_##name(); \
}
#endif
/** \rst
@@ -43,75 +43,78 @@
});
}
\endrst */
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} catch (pybind11::error_already_set & e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
pybind11::detail::embedded_module name( \
PYBIND11_TOSTRING(name), PYBIND11_CONCAT(pybind11_init_impl_, name)); \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module & variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error
/// checks.
struct embedded_module {
#if PY_MAJOR_VERSION >= 3
using init_t = PyObject *(*)();
using init_t = PyObject *(*)();
#else
using init_t = void (*)();
using init_t = void (*)();
#endif
embedded_module(const char *name, init_t init) {
if (Py_IsInitialized())
pybind11_fail("Can't add new modules after the interpreter has been initialized");
embedded_module(const char *name, init_t init) {
if (Py_IsInitialized())
pybind11_fail(
"Can't add new modules after the interpreter has been initialized");
auto result = PyImport_AppendInittab(name, init);
if (result == -1)
pybind11_fail("Insufficient memory to add a new module");
}
auto result = PyImport_AppendInittab(name, init);
if (result == -1)
pybind11_fail("Insufficient memory to add a new module");
}
};
NAMESPACE_END(detail)
/** \rst
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
optional parameter can be used to skip the registration of signal handlers (see the
`Python documentation`_ for details). Calling this function again after the interpreter
has already been initialized is a fatal error.
Initialize the Python interpreter. No other pybind11 or CPython API
functions can be called before this is done; with the exception of
`PYBIND11_EMBEDDED_MODULE`. The optional parameter can be used to skip the
registration of signal handlers (see the `Python documentation`_ for details).
Calling this function again after the interpreter has already been initialized
is a fatal error.
If initializing the Python interpreter fails, then the program is terminated. (This
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
of throwing exceptions on errors.)
If initializing the Python interpreter fails, then the program is
terminated. (This is controlled by the CPython runtime and is an exception to
pybind11's normal behavior of throwing exceptions on errors.)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
.. _Python documentation:
https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx \endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");
Py_InitializeEx(init_signal_handlers ? 1 : 0);
Py_InitializeEx(init_signal_handlers ? 1 : 0);
// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
}
/** \rst
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
after this. In addition, pybind11 objects must not outlive the interpreter:
Shut down the Python interpreter. No pybind11 or CPython API functions can
be called after this. In addition, pybind11 objects must not outlive the
interpreter:
.. code-block:: cpp
@@ -136,32 +139,33 @@ inline void initialize_interpreter(bool init_signal_handlers = true) {
.. warning::
The interpreter can be restarted by calling `initialize_interpreter` again.
Modules created using pybind11 can be safely re-initialized. However, Python
itself cannot completely unload binary extension modules and there are several
caveats with regard to interpreter restarting. All the details can be found
in the CPython documentation. In short, not all interpreter memory may be
The interpreter can be restarted by calling `initialize_interpreter`
again. Modules created using pybind11 can be safely re-initialized. However,
Python itself cannot completely unload binary extension modules and there are
several caveats with regard to interpreter restarting. All the details can be
found in the CPython documentation. In short, not all interpreter memory may be
freed, either due to reference cycles or user-created global data.
\endrst */
inline void finalize_interpreter() {
handle builtins(PyEval_GetBuiltins());
const char *id = PYBIND11_INTERNALS_ID;
handle builtins(PyEval_GetBuiltins());
const char *id = PYBIND11_INTERNALS_ID;
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
// It could also be stashed in builtins, so look there too:
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
internals_ptr_ptr = capsule(builtins[id]);
// Get the internals pointer (without creating it if it doesn't exist). It's
// possible for the internals to be created during Py_Finalize() (e.g. if a
// py::capsule calls `get_internals()` during destruction), so we get the
// pointer-pointer here and check it after Py_Finalize().
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
// It could also be stashed in builtins, so look there too:
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
internals_ptr_ptr = capsule(builtins[id]);
Py_Finalize();
Py_Finalize();
if (internals_ptr_ptr) {
delete *internals_ptr_ptr;
*internals_ptr_ptr = nullptr;
}
if (internals_ptr_ptr) {
delete *internals_ptr_ptr;
*internals_ptr_ptr = nullptr;
}
}
/** \rst
@@ -179,22 +183,24 @@ inline void finalize_interpreter() {
\endrst */
class scoped_interpreter {
public:
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
}
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
}
scoped_interpreter(const scoped_interpreter &) = delete;
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
scoped_interpreter(const scoped_interpreter &) = delete;
scoped_interpreter(scoped_interpreter &&other) noexcept {
other.is_valid = false;
}
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
~scoped_interpreter() {
if (is_valid)
finalize_interpreter();
}
~scoped_interpreter() {
if (is_valid)
finalize_interpreter();
}
private:
bool is_valid = true;
bool is_valid = true;
};
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -2,8 +2,8 @@
pybind11/exec.h: Support for evaluating Python expressions and statements
from strings and files
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de>
and Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
@@ -16,102 +16,119 @@
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
enum eval_mode {
/// Evaluate a string containing an isolated expression
eval_expr,
/// Evaluate a string containing an isolated expression
eval_expr,
/// Evaluate a string containing a single statement. Returns \c none
eval_single_statement,
/// Evaluate a string containing a single statement. Returns \c none
eval_single_statement,
/// Evaluate a string containing a sequence of statement. Returns \c none
eval_statements
/// Evaluate a string containing a sequence of statement. Returns \c none
eval_statements
};
template <eval_mode mode = eval_expr>
object eval(str expr, object global = globals(), object local = object()) {
if (!local)
local = global;
if (!local)
local = global;
/* PyRun_String does not accept a PyObject / encoding specifier,
this seems to be the only alternative */
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
/* PyRun_String does not accept a PyObject / encoding specifier,
this seems to be the only alternative */
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string)expr;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
int start;
switch (mode) {
case eval_expr:
start = Py_eval_input;
break;
case eval_single_statement:
start = Py_single_input;
break;
case eval_statements:
start = Py_file_input;
break;
default:
pybind11_fail("invalid evaluation mode");
}
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
PyObject *result =
PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
template <eval_mode mode = eval_expr, size_t N>
object eval(const char (&s)[N], object global = globals(), object local = object()) {
/* Support raw string literals by removing common leading whitespace */
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
: str(s);
return eval<mode>(expr, global, local);
object eval(const char (&s)[N], object global = globals(),
object local = object()) {
/* Support raw string literals by removing common leading whitespace */
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
: str(s);
return eval<mode>(expr, global, local);
}
inline void exec(str expr, object global = globals(), object local = object()) {
eval<eval_statements>(expr, global, local);
eval<eval_statements>(expr, global, local);
}
template <size_t N>
void exec(const char (&s)[N], object global = globals(), object local = object()) {
eval<eval_statements>(s, global, local);
void exec(const char (&s)[N], object global = globals(),
object local = object()) {
eval<eval_statements>(s, global, local);
}
template <eval_mode mode = eval_statements>
object eval_file(str fname, object global = globals(), object local = object()) {
if (!local)
local = global;
object eval_file(str fname, object global = globals(),
object local = object()) {
if (!local)
local = global;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
int start;
switch (mode) {
case eval_expr:
start = Py_eval_input;
break;
case eval_single_statement:
start = Py_single_input;
break;
case eval_statements:
start = Py_file_input;
break;
default:
pybind11_fail("invalid evaluation mode");
}
int closeFile = 1;
std::string fname_str = (std::string) fname;
int closeFile = 1;
std::string fname_str = (std::string)fname;
#if PY_VERSION_HEX >= 0x03040000
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
#elif PY_VERSION_HEX >= 0x03000000
FILE *f = _Py_fopen(fname.ptr(), "r");
FILE *f = _Py_fopen(fname.ptr(), "r");
#else
/* No unicode support in open() :( */
auto fobj = reinterpret_steal<object>(PyFile_FromString(
const_cast<char *>(fname_str.c_str()),
const_cast<char*>("r")));
FILE *f = nullptr;
if (fobj)
f = PyFile_AsFile(fobj.ptr());
closeFile = 0;
/* No unicode support in open() :( */
auto fobj = reinterpret_steal<object>(PyFile_FromString(
const_cast<char *>(fname_str.c_str()), const_cast<char *>("r")));
FILE *f = nullptr;
if (fobj)
f = PyFile_AsFile(fobj.ptr());
closeFile = 0;
#endif
if (!f) {
PyErr_Clear();
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
}
if (!f) {
PyErr_Clear();
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
}
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
local.ptr());
(void) closeFile;
PyObject *result =
PyRun_File(f, fname_str.c_str(), start, global.ptr(), local.ptr());
(void)closeFile;
#else
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
local.ptr(), closeFile);
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
local.ptr(), closeFile);
#endif
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -17,91 +17,99 @@ NAMESPACE_BEGIN(detail)
template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> {
using type = std::function<Return(Args...)>;
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
using function_type = Return (*) (Args...);
using type = std::function<Return(Args...)>;
using retval_type =
conditional_t<std::is_same<Return, void>::value, void_type, Return>;
using function_type = Return (*)(Args...);
public:
bool load(handle src, bool convert) {
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert) return false;
return true;
}
bool load(handle src, bool convert) {
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert)
return false;
return true;
}
if (!isinstance<function>(src))
return false;
if (!isinstance<function>(src))
return false;
auto func = reinterpret_borrow<function>(src);
auto func = reinterpret_borrow<function>(src);
/*
When passing a C++ function as an argument to another C++
function via Python, every function call would normally involve
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
Here, we try to at least detect the case where the function is
stateless (i.e. function pointer or lambda function without
captured variables), in which case the roundtrip can be avoided.
*/
if (auto cfunc = func.cpp_function()) {
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *) c;
/*
When passing a C++ function as an argument to another C++
function via Python, every function call would normally involve
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
Here, we try to at least detect the case where the function is
stateless (i.e. function pointer or lambda function without
captured variables), in which case the roundtrip can be avoided.
*/
if (auto cfunc = func.cpp_function()) {
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *)c;
if (rec && rec->is_stateless &&
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
struct capture { function_type f; };
value = ((capture *) &rec->data)->f;
return true;
}
}
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function&& f_) : f(std::move(f_)) {}
func_handle(const func_handle&) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
if (rec && rec->is_stateless &&
same_type(typeid(function_type),
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
struct capture {
function_type f;
};
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
// gil_scoped_acquire acq;
// object retval(hfunc.f(std::forward<Args>(args)...));
// /* Visual studio 2015 parser issue: need parentheses around this expression */
// return (retval.template cast<Return>());
// };
struct func_wrapper {
func_handle hfunc;
func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
}
};
value = func_wrapper(func_handle(std::move(func)));
value = ((capture *)&rec->data)->f;
return true;
}
}
template <typename Func>
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
if (!f_)
return none().inc_ref();
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function &&f_) : f(std::move(f_)) {}
func_handle(const func_handle &) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
auto result = f_.template target<function_type>();
if (result)
return cpp_function(*result, policy).release();
else
return cpp_function(std::forward<Func>(f_), policy).release();
}
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
// gil_scoped_acquire acq;
// object retval(hfunc.f(std::forward<Args>(args)...));
// /* Visual studio 2015 parser issue: need parentheses around this
// expression */ return (retval.template cast<Return>());
// };
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
+ make_caster<retval_type>::name + _("]"));
struct func_wrapper {
func_handle hfunc;
func_wrapper(func_handle &&hf) : hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this
* expression */
return (retval.template cast<Return>());
}
};
value = func_wrapper(func_handle(std::move(func)));
return true;
}
template <typename Func>
static handle cast(Func &&f_, return_value_policy policy,
handle /* parent */) {
if (!f_)
return none().inc_ref();
auto result = f_.template target<function_type>();
if (result)
return cpp_function(*result, policy).release();
else
return cpp_function(std::forward<Func>(f_), policy).release();
}
PYBIND11_TYPE_CASTER(type, _("Callable[[") +
concat(make_caster<Args>::name...) + _("], ") +
make_caster<retval_type>::name + _("]"));
};
NAMESPACE_END(detail)

View File

@@ -1,5 +1,6 @@
/*
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to
Python
Copyright (c) 2017 Henry F. Schreiner
@@ -11,11 +12,11 @@
#include "pybind11.h"
#include <streambuf>
#include <ostream>
#include <string>
#include <memory>
#include <iostream>
#include <memory>
#include <ostream>
#include <streambuf>
#include <string>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
@@ -23,56 +24,50 @@ NAMESPACE_BEGIN(detail)
// Buffer that writes to Python instead of C++
class pythonbuf : public std::streambuf {
private:
using traits_type = std::streambuf::traits_type;
using traits_type = std::streambuf::traits_type;
const size_t buf_size;
std::unique_ptr<char[]> d_buffer;
object pywrite;
object pyflush;
const size_t buf_size;
std::unique_ptr<char[]> d_buffer;
object pywrite;
object pyflush;
int overflow(int c) {
if (!traits_type::eq_int_type(c, traits_type::eof())) {
*pptr() = traits_type::to_char_type(c);
pbump(1);
}
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
int overflow(int c) {
if (!traits_type::eq_int_type(c, traits_type::eof())) {
*pptr() = traits_type::to_char_type(c);
pbump(1);
}
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
}
int sync() {
if (pbase() != pptr()) {
// This subtraction cannot be negative, so dropping the sign
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
int sync() {
if (pbase() != pptr()) {
// This subtraction cannot be negative, so dropping the sign
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
{
gil_scoped_acquire tmp;
pywrite(line);
pyflush();
}
{
gil_scoped_acquire tmp;
pywrite(line);
pyflush();
}
setp(pbase(), epptr());
}
return 0;
setp(pbase(), epptr());
}
return 0;
}
public:
pythonbuf(object pyostream, size_t buffer_size = 1024)
: buf_size(buffer_size), d_buffer(new char[buf_size]),
pywrite(pyostream.attr("write")), pyflush(pyostream.attr("flush")) {
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
}
pythonbuf(object pyostream, size_t buffer_size = 1024)
: buf_size(buffer_size),
d_buffer(new char[buf_size]),
pywrite(pyostream.attr("write")),
pyflush(pyostream.attr("flush")) {
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
}
/// Sync before destroy
~pythonbuf() {
sync();
}
/// Sync before destroy
~pythonbuf() { sync(); }
};
NAMESPACE_END(detail)
/** \rst
This a move-only guard that redirects output.
@@ -93,35 +88,32 @@ NAMESPACE_END(detail)
.. code-block:: cpp
{
py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
std::cerr << "Hello, World!";
py::scoped_ostream_redirect output{std::cerr,
py::module::import("sys").attr("stderr")}; std::cerr << "Hello, World!";
}
\endrst */
class scoped_ostream_redirect {
protected:
std::streambuf *old;
std::ostream &costream;
detail::pythonbuf buffer;
std::streambuf *old;
std::ostream &costream;
detail::pythonbuf buffer;
public:
scoped_ostream_redirect(
std::ostream &costream = std::cout,
object pyostream = module::import("sys").attr("stdout"))
: costream(costream), buffer(pyostream) {
old = costream.rdbuf(&buffer);
}
scoped_ostream_redirect(
std::ostream &costream = std::cout,
object pyostream = module::import("sys").attr("stdout"))
: costream(costream), buffer(pyostream) {
old = costream.rdbuf(&buffer);
}
~scoped_ostream_redirect() {
costream.rdbuf(old);
}
~scoped_ostream_redirect() { costream.rdbuf(old); }
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
};
/** \rst
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
is provided primary to make ``py::call_guard`` easier to make.
@@ -135,44 +127,44 @@ public:
\endrst */
class scoped_estream_redirect : public scoped_ostream_redirect {
public:
scoped_estream_redirect(
std::ostream &costream = std::cerr,
object pyostream = module::import("sys").attr("stderr"))
: scoped_ostream_redirect(costream,pyostream) {}
scoped_estream_redirect(
std::ostream &costream = std::cerr,
object pyostream = module::import("sys").attr("stderr"))
: scoped_ostream_redirect(costream, pyostream) {}
};
NAMESPACE_BEGIN(detail)
// Class to redirect output as a context manager. C++ backend.
class OstreamRedirect {
bool do_stdout_;
bool do_stderr_;
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
bool do_stdout_;
bool do_stderr_;
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
public:
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
void enter() {
if (do_stdout_)
redirect_stdout.reset(new scoped_ostream_redirect());
if (do_stderr_)
redirect_stderr.reset(new scoped_estream_redirect());
}
void enter() {
if (do_stdout_)
redirect_stdout.reset(new scoped_ostream_redirect());
if (do_stderr_)
redirect_stderr.reset(new scoped_estream_redirect());
}
void exit() {
redirect_stdout.reset();
redirect_stderr.reset();
}
void exit() {
redirect_stdout.reset();
redirect_stderr.reset();
}
};
NAMESPACE_END(detail)
/** \rst
This is a helper function to add a C++ redirect context manager to Python
instead of using a C++ guard. To use it, add the following to your binding code:
instead of using a C++ guard. To use it, add the following to your binding
code:
.. code-block:: cpp
@@ -197,11 +189,13 @@ NAMESPACE_END(detail)
m.noisy_function_with_error_printing()
\endrst */
inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
.def("__enter__", &detail::OstreamRedirect::enter)
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
inline class_<detail::OstreamRedirect>
add_ostream_redirect(module m, std::string name = "ostream_redirect") {
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
.def(init<bool, bool>(), arg("stdout") = true, arg("stderr") = true)
.def("__enter__", &detail::OstreamRedirect::enter)
.def("__exit__",
[](detail::OstreamRedirect &self_, args) { self_.exit(); });
}
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

View File

@@ -12,10 +12,13 @@
#include "pybind11.h"
#if defined(__clang__) && !defined(__INTEL_COMPILER)
# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
#pragma clang diagnostic ignored \
"-Wunsequenced" // multiple unsequenced modifications to 'self' (when using
// def(py::self OP Type()))
#elif defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#pragma warning(push)
#pragma warning( \
disable : 4127) // warning C4127: Conditional expression is constant
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
@@ -23,136 +26,191 @@ NAMESPACE_BEGIN(detail)
/// Enumeration with all supported operator types
enum op_id : int {
op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
op_repr, op_truediv, op_itruediv, op_hash
op_add,
op_sub,
op_mul,
op_div,
op_mod,
op_divmod,
op_pow,
op_lshift,
op_rshift,
op_and,
op_xor,
op_or,
op_neg,
op_pos,
op_abs,
op_invert,
op_int,
op_long,
op_float,
op_str,
op_cmp,
op_gt,
op_ge,
op_lt,
op_le,
op_eq,
op_ne,
op_iadd,
op_isub,
op_imul,
op_idiv,
op_imod,
op_ilshift,
op_irshift,
op_iand,
op_ixor,
op_ior,
op_complex,
op_bool,
op_nonzero,
op_repr,
op_truediv,
op_itruediv,
op_hash
};
enum op_type : int {
op_l, /* base type on left */
op_r, /* base type on right */
op_u /* unary operator */
op_l, /* base type on left */
op_r, /* base type on right */
op_u /* unary operator */
};
struct self_t { };
struct self_t {};
static const self_t self = self_t();
/// Type for an unused type slot
struct undefined_t { };
struct undefined_t {};
/// Don't warn about an unused variable
inline self_t __self() { return self; }
/// base template of operator implementations
template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
template <op_id, op_type, typename B, typename L, typename R> struct op_impl {};
/// Operator implementation generator
template <op_id id, op_type ot, typename L, typename R> struct op_ {
template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__"
: ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
template <typename Class, typename... Extra>
void execute_cast(Class &cl, const Extra &... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__"
: ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
};
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
static B execute_cast(const L &l, const R &r) { return B(expr); } \
}; \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
static char const* name() { return "__" #rid "__"; } \
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
static B execute_cast(const R &r, const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
return op_<op_##id, op_l, self_t, self_t>(); \
} \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
} \
template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
return op_<op_##id, op_r, T, self_t>(); \
}
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
template <typename B, typename L, typename R> \
struct op_impl<op_##id, op_l, B, L, R> { \
static char const *name() { return "__" #id "__"; } \
static auto execute(const L &l, const R &r) -> decltype(expr) { \
return (expr); \
} \
static B execute_cast(const L &l, const R &r) { return B(expr); } \
}; \
template <typename B, typename L, typename R> \
struct op_impl<op_##id, op_r, B, L, R> { \
static char const *name() { return "__" #rid "__"; } \
static auto execute(const R &r, const L &l) -> decltype(expr) { \
return (expr); \
} \
static B execute_cast(const R &r, const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, \
const self_t &) { \
return op_<op_##id, op_l, self_t, self_t>(); \
} \
template <typename T> \
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
} \
template <typename T> \
op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
return op_<op_##id, op_r, T, self_t>(); \
}
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
static B execute_cast(L &l, const R &r) { return B(expr); } \
}; \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
}
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
template <typename B, typename L, typename R> \
struct op_impl<op_##id, op_l, B, L, R> { \
static char const *name() { return "__" #id "__"; } \
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
static B execute_cast(L &l, const R &r) { return B(expr); } \
}; \
template <typename T> \
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
}
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l) -> decltype(expr) { return expr; } \
static B execute_cast(const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
return op_<op_##id, op_u, self_t, undefined_t>(); \
}
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
template <typename B, typename L> \
struct op_impl<op_##id, op_u, B, L, undefined_t> { \
static char const *name() { return "__" #id "__"; } \
static auto execute(const L &l) -> decltype(expr) { return expr; } \
static B execute_cast(const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
return op_<op_##id, op_u, self_t, undefined_t>(); \
}
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l>> r)
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l> r)
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l<r)
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
// PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l,
// r))
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
PYBIND11_UNARY_OPERATOR(neg, operator-, - l)
PYBIND11_UNARY_OPERATOR(pos, operator+, + l)
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
PYBIND11_UNARY_OPERATOR(invert, operator~,(~l))
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
PYBIND11_UNARY_OPERATOR(int, int_, (int)l)
PYBIND11_UNARY_OPERATOR(float, float_, (double)l)
#undef PYBIND11_BINARY_OPERATOR
#undef PYBIND11_INPLACE_OPERATOR
@@ -164,5 +222,5 @@ using detail::self;
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER)
# pragma warning(pop)
#pragma warning(pop)
#endif

View File

@@ -15,51 +15,65 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
class options {
public:
// Default RAII constructor, which leaves settings as they currently are.
options() : previous_state(global_state()) {}
// Default RAII constructor, which leaves settings as they currently are.
options() : previous_state(global_state()) {}
// Class is non-copyable.
options(const options &) = delete;
options &operator=(const options &) = delete;
// Class is non-copyable.
options(const options&) = delete;
options& operator=(const options&) = delete;
// Destructor, which restores settings that were in effect before.
~options() { global_state() = previous_state; }
// Destructor, which restores settings that were in effect before.
~options() {
global_state() = previous_state;
}
// Setter methods (affect the global state):
// Setter methods (affect the global state):
options &disable_user_defined_docstrings() & {
global_state().show_user_defined_docstrings = false;
return *this;
}
options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
options &enable_user_defined_docstrings() & {
global_state().show_user_defined_docstrings = true;
return *this;
}
options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
options &disable_function_signatures() & {
global_state().show_function_signatures = false;
return *this;
}
options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
options &enable_function_signatures() & {
global_state().show_function_signatures = true;
return *this;
}
options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
// Getter methods (return the global state):
// Getter methods (return the global state):
static bool show_user_defined_docstrings() {
return global_state().show_user_defined_docstrings;
}
static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
static bool show_function_signatures() {
return global_state().show_function_signatures;
}
static bool show_function_signatures() { return global_state().show_function_signatures; }
// This type is not meant to be allocated on the heap.
void* operator new(size_t) = delete;
// This type is not meant to be allocated on the heap.
void *operator new(size_t) = delete;
private:
struct state {
bool show_user_defined_docstrings =
true; //< Include user-supplied texts in docstrings.
bool show_function_signatures =
true; //< Include auto-generated function signatures in docstrings.
};
struct state {
bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
};
static state &global_state() {
static state instance;
return instance;
}
static state &global_state() {
static state instance;
return instance;
}
state previous_state;
state previous_state;
};
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -10,373 +10,411 @@
#pragma once
#include "pybind11.h"
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>
#include <deque>
#include <iostream>
#include <list>
#include <deque>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <valarray>
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#pragma warning( \
disable : 4127) // warning C4127: Conditional expression is constant
#endif
#ifdef __has_include
// std::optional (but including it in c++14 mode isn't allowed)
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
# include <optional>
# define PYBIND11_HAS_OPTIONAL 1
# endif
#if defined(PYBIND11_CPP17) && __has_include(<optional>)
#include <optional>
#define PYBIND11_HAS_OPTIONAL 1
#endif
// std::experimental::optional (but not allowed in c++11 mode)
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
#if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
!__has_include(<optional>))
# include <experimental/optional>
# define PYBIND11_HAS_EXP_OPTIONAL 1
# endif
#include <experimental/optional>
#define PYBIND11_HAS_EXP_OPTIONAL 1
#endif
// std::variant
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
# include <variant>
# define PYBIND11_HAS_VARIANT 1
# endif
#if defined(PYBIND11_CPP17) && __has_include(<variant>)
#include <variant>
#define PYBIND11_HAS_VARIANT 1
#endif
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
# include <optional>
# include <variant>
# define PYBIND11_HAS_OPTIONAL 1
# define PYBIND11_HAS_VARIANT 1
#include <optional>
#include <variant>
#define PYBIND11_HAS_OPTIONAL 1
#define PYBIND11_HAS_VARIANT 1
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
/// Extracts an const lvalue reference or rvalue reference for U based on the
/// type of T (e.g. for forwarding a container element). Typically used
/// indirect via forwarded_type(), below.
template <typename T, typename U>
using forwarded_type = conditional_t<
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
using forwarded_type =
conditional_t<std::is_lvalue_reference<T>::value, remove_reference_t<U> &,
remove_reference_t<U> &&>;
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
/// used for forwarding a container's elements.
template <typename T, typename U>
forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or
/// lvalue; typically used for forwarding a container's elements.
template <typename T, typename U> forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
}
template <typename Type, typename Key> struct set_caster {
using type = Type;
using key_conv = make_caster<Key>;
using type = Type;
using key_conv = make_caster<Key>;
bool load(handle src, bool convert) {
if (!isinstance<pybind11::set>(src))
return false;
auto s = reinterpret_borrow<pybind11::set>(src);
value.clear();
for (auto entry : s) {
key_conv conv;
if (!conv.load(entry, convert))
return false;
value.insert(cast_op<Key &&>(std::move(conv)));
}
return true;
bool load(handle src, bool convert) {
if (!isinstance<pybind11::set>(src))
return false;
auto s = reinterpret_borrow<pybind11::set>(src);
value.clear();
for (auto entry : s) {
key_conv conv;
if (!conv.load(entry, convert))
return false;
value.insert(cast_op<Key &&>(std::move(conv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(value_))
return handle();
}
return s.release();
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(value_))
return handle();
}
return s.release();
}
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
};
template <typename Type, typename Key, typename Value> struct map_caster {
using key_conv = make_caster<Key>;
using value_conv = make_caster<Value>;
using key_conv = make_caster<Key>;
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<dict>(src))
return false;
auto d = reinterpret_borrow<dict>(src);
value.clear();
for (auto it : d) {
key_conv kconv;
value_conv vconv;
if (!kconv.load(it.first.ptr(), convert) ||
!vconv.load(it.second.ptr(), convert))
return false;
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
}
return true;
bool load(handle src, bool convert) {
if (!isinstance<dict>(src))
return false;
auto d = reinterpret_borrow<dict>(src);
value.clear();
for (auto it : d) {
key_conv kconv;
value_conv vconv;
if (!kconv.load(it.first.ptr(), convert) ||
!vconv.load(it.second.ptr(), convert))
return false;
value.emplace(cast_op<Key &&>(std::move(kconv)),
cast_op<Value &&>(std::move(vconv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
if (!key || !value)
return handle();
d[key] = value;
}
return d.release();
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(
key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(
value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
if (!key || !value)
return handle();
d[key] = value;
}
return d.release();
}
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") +
value_conv::name + _("]"));
};
template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
reserve_maybe(s, &value);
for (auto it : s) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value.push_back(cast_op<Value &&>(std::move(conv)));
}
return true;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
reserve_maybe(s, &value);
for (auto it : s) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value.push_back(cast_op<Value &&>(std::move(conv)));
}
return true;
}
private:
template <typename T = Type,
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
void reserve_maybe(sequence, void *) { }
template <typename T = Type,
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)),
void>::value,
int> = 0>
void reserve_maybe(sequence s, Type *) {
value.reserve(s.size());
}
void reserve_maybe(sequence, void *) {}
public:
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t)index++,
value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
};
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> { };
template <typename Type, typename Alloc>
struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> {};
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> { };
template <typename Type, typename Alloc>
struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> {};
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> { };
template <typename Type, typename Alloc>
struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> {};
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
using value_conv = make_caster<Value>;
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0>
struct array_caster {
using value_conv = make_caster<Value>;
private:
template <bool R = Resizable>
bool require_size(enable_if_t<R, size_t> size) {
if (value.size() != size)
value.resize(size);
return true;
}
template <bool R = Resizable>
bool require_size(enable_if_t<!R, size_t> size) {
return size == Size;
}
template <bool R = Resizable> bool require_size(enable_if_t<R, size_t> size) {
if (value.size() != size)
value.resize(size);
return true;
}
template <bool R = Resizable>
bool require_size(enable_if_t<!R, size_t> size) {
return size == Size;
}
public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
for (auto it : l) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value[ctr++] = cast_op<Value &&>(std::move(conv));
}
return true;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
for (auto it : l) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value[ctr++] = cast_op<Value &&>(std::move(conv));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(
value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t)index++,
value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
PYBIND11_TYPE_CASTER(ArrayType,
_("List[") + value_conv::name +
_<Resizable>(_(""), _("[") + _<Size>() + _("]")) +
_("]"));
};
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size> { };
template <typename Type, size_t Size>
struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size> {};
template <typename Type> struct type_caster<std::valarray<Type>>
: array_caster<std::valarray<Type>, Type, true> { };
template <typename Type>
struct type_caster<std::valarray<Type>>
: array_caster<std::valarray<Type>, Type, true> {};
template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key> { };
template <typename Key, typename Compare, typename Alloc>
struct type_caster<std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key> {};
template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
template <typename Key, typename Hash, typename Equal, typename Alloc>
struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> {};
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
template <typename Key, typename Value, typename Compare, typename Alloc>
struct type_caster<std::map<Key, Value, Compare, Alloc>>
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> {};
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
template <typename Key, typename Value, typename Hash, typename Equal,
typename Alloc>
struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key,
Value> {};
// This type caster is intended to be used for std::optional and std::experimental::optional
template<typename T> struct optional_caster {
using value_conv = make_caster<typename T::value_type>;
// This type caster is intended to be used for std::optional and
// std::experimental::optional
template <typename T> struct optional_caster {
using value_conv = make_caster<typename T::value_type>;
template <typename T_>
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
if (!src)
return none().inc_ref();
policy = return_value_policy_override<typename T::value_type>::policy(policy);
return value_conv::cast(*std::forward<T_>(src), policy, parent);
template <typename T_>
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
if (!src)
return none().inc_ref();
policy =
return_value_policy_override<typename T::value_type>::policy(policy);
return value_conv::cast(*std::forward<T_>(src), policy, parent);
}
bool load(handle src, bool convert) {
if (!src) {
return false;
} else if (src.is_none()) {
return true; // default-constructed value is already empty
}
value_conv inner_caster;
if (!inner_caster.load(src, convert))
return false;
bool load(handle src, bool convert) {
if (!src) {
return false;
} else if (src.is_none()) {
return true; // default-constructed value is already empty
}
value_conv inner_caster;
if (!inner_caster.load(src, convert))
return false;
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
return true;
}
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
return true;
}
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
};
#if PYBIND11_HAS_OPTIONAL
template<typename T> struct type_caster<std::optional<T>>
template <typename T>
struct type_caster<std::optional<T>>
: public optional_caster<std::optional<T>> {};
template<> struct type_caster<std::nullopt_t>
: public void_caster<std::nullopt_t> {};
template <>
struct type_caster<std::nullopt_t> : public void_caster<std::nullopt_t> {};
#endif
#if PYBIND11_HAS_EXP_OPTIONAL
template<typename T> struct type_caster<std::experimental::optional<T>>
template <typename T>
struct type_caster<std::experimental::optional<T>>
: public optional_caster<std::experimental::optional<T>> {};
template<> struct type_caster<std::experimental::nullopt_t>
template <>
struct type_caster<std::experimental::nullopt_t>
: public void_caster<std::experimental::nullopt_t> {};
#endif
/// Visit a variant and cast any found type to Python
struct variant_caster_visitor {
return_value_policy policy;
handle parent;
return_value_policy policy;
handle parent;
using result_type = handle; // required by boost::variant in C++11
using result_type = handle; // required by boost::variant in C++11
template <typename T>
result_type operator()(T &&src) const {
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
}
template <typename T> result_type operator()(T &&src) const {
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
}
};
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
/// automatically using argument-dependent lookup. Users can provide specializations for other
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
template <template<typename...> class Variant>
struct visit_helper {
template <typename... Args>
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
return visit(std::forward<Args>(args)...);
}
/// Helper class which abstracts away variant's `visit` function. `std::variant`
/// and similar `namespace::variant` types which provide a `namespace::visit()`
/// function are handled here automatically using argument-dependent lookup.
/// Users can provide specializations for other variant-like classes, e.g.
/// `boost::variant` and `boost::apply_visitor`.
template <template <typename...> class Variant> struct visit_helper {
template <typename... Args>
static auto call(Args &&... args)
-> decltype(visit(std::forward<Args>(args)...)) {
return visit(std::forward<Args>(args)...);
}
};
/// Generic variant caster
template <typename Variant> struct variant_caster;
template <template<typename...> class V, typename... Ts>
template <template <typename...> class V, typename... Ts>
struct variant_caster<V<Ts...>> {
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
static_assert(sizeof...(Ts) > 0,
"Variant must consist of at least one alternative.");
template <typename U, typename... Us>
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
auto caster = make_caster<U>();
if (caster.load(src, convert)) {
value = cast_op<U>(caster);
return true;
}
return load_alternative(src, convert, type_list<Us...>{});
template <typename U, typename... Us>
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
auto caster = make_caster<U>();
if (caster.load(src, convert)) {
value = cast_op<U>(caster);
return true;
}
return load_alternative(src, convert, type_list<Us...>{});
}
bool load_alternative(handle, bool, type_list<>) { return false; }
bool load_alternative(handle, bool, type_list<>) { return false; }
bool load(handle src, bool convert) {
// Do a first pass without conversions to improve constructor resolution.
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
// slot of the variant. Without two-pass loading `double` would be filled
// because it appears first and a conversion is possible.
if (convert && load_alternative(src, false, type_list<Ts...>{}))
return true;
return load_alternative(src, convert, type_list<Ts...>{});
}
bool load(handle src, bool convert) {
// Do a first pass without conversions to improve constructor resolution.
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
// slot of the variant. Without two-pass loading `double` would be filled
// because it appears first and a conversion is possible.
if (convert && load_alternative(src, false, type_list<Ts...>{}))
return true;
return load_alternative(src, convert, type_list<Ts...>{});
}
template <typename Variant>
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
std::forward<Variant>(src));
}
template <typename Variant>
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
std::forward<Variant>(src));
}
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") +
detail::concat(make_caster<Ts>::name...) +
_("]"));
};
#if PYBIND11_HAS_VARIANT
template <typename... Ts>
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> {
};
#endif
NAMESPACE_END(detail)
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
os << (std::string) str(obj);
return os;
os << (std::string)str(obj);
return os;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,48 +1,51 @@
#include "triton/Analysis/AxisInfo.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/AxisInfo.h"
using namespace mlir;
namespace{
namespace {
struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
// LLVM15+
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
void print(const std::string &name, raw_ostream &os, ArrayRef<int> vals) {
os << name << ": [";
for(size_t d = 0; d < vals.size(); d++){
if(d != 0) os << ", ";
for (size_t d = 0; d < vals.size(); d++) {
if (d != 0)
os << ", ";
os << vals[d];
}
os << "]";
}
StringRef getArgument() const final { return "test-print-alignment"; }
StringRef getDescription() const final
{ return "print the result of the alignment analysis pass"; }
StringRef getDescription() const final {
return "print the result of the alignment analysis pass";
}
void runOnOperation() override {
Operation* operation = getOperation();
auto& os = llvm::errs();
Operation *operation = getOperation();
auto &os = llvm::errs();
os << "Testing: " << operation->getName() << "\n";
AxisInfoAnalysis analysis(&getContext());
analysis.run(operation);
operation->walk([&](Operation* op){
if(op->getNumResults() < 1)
operation->walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for(Value result: op->getResults()){
for (Value result : op->getResults()) {
// std::ostringstream oss;
// result.print(oss);
// os << " => ";
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
if(!latticeElement){
LatticeElement<AxisInfo> *latticeElement =
analysis.lookupLatticeElement(result);
if (!latticeElement) {
os << "None\n";
return;
}
AxisInfo& info = latticeElement->getValue();
AxisInfo &info = latticeElement->getValue();
print("Contiguity", os, info.getContiguity());
os << " ; ";
print("Divisibility", os, info.getDivisibility());
@@ -50,18 +53,17 @@ struct TestAxisInfoPass
print("Constancy", os, info.getConstancy());
os << " ( ";
result.print(os);
os << " ) ";
os << " ) ";
os << "\n";
}
});
}
};
}
} // namespace
namespace mlir{
namespace test{
namespace mlir {
namespace test {
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
}
}
} // namespace test
} // namespace mlir