History prior to this date belonged to the now-deprecated ISAAC project and was deleted to save space.

commit 6d7cf35123
Author: Philippe Tillet
Date:   2021-07-27 12:38:38 -07:00

202 changed files with 94034 additions and 0 deletions


@@ -0,0 +1,48 @@
import torch
import numpy as np
import reference
import optimized
from time import time
use_half = True
def cast(x):
if use_half:
return x.half()
else:
return x
# GPU device
device = torch.device("cuda:0")
# shapes
batch, nhead = 8, 28
dm, dk, dv = 1024, 1024, 1024
lq, lk, lv = 1024, 1024, 1024
# initialize tensors
torch.manual_seed(0)
np.random.seed(0)
query = cast(torch.randn(batch, lq, dm)).cuda()
key = cast(torch.randn(batch, lk, dm)).cuda()
value = cast(torch.randn(batch, lv, dm)).cuda()
# initialize layers
torch.manual_seed(0)
np.random.seed(0)
rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device))
torch.manual_seed(0)
np.random.seed(0)
tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device))
# test
routput, _ = rattn(query, key, value)
toutput, _ = tattn(query, key, value)
diff = torch.max(torch.abs(routput - toutput))
assert diff < 1e-2
# benchmark
# CUDA kernels launch asynchronously, so synchronize around each timed
# region; otherwise the wall-clock numbers only measure launch overhead
torch.cuda.synchronize()
start = time()
routput, _ = rattn(query, key, value)
torch.cuda.synchronize()
end = time()
rtime = end - start
start = time()
toutput, _ = tattn(query, key, value)
torch.cuda.synchronize()
end = time()
ttime = end - start
print(f'Torch: {rtime} s')
print(f'Triton: {ttime} s')


@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn as nn
import triton
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# linear layers
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.fc = nn.Linear(n_head * d_v, d_model)
# initialize weights
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
nn.init.xavier_normal_(self.fc.weight)
# layer normalization
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, q, k, v, mask=None):
# dimensions
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
# linear transformations
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# scaled dot-product attention
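        # triton.ops.einsum takes the output shape as an explicit last argument;
        # 'blhk,bthk->hblt' contracts q and k over d_k, yielding attention
        # scores of shape (n_head, sz_b, len_q, len_k)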
attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k])
attn = attn / np.sqrt(d_k)
if mask is not None:
attn = attn.masked_fill(mask[None], -np.inf)
attn = torch.softmax(attn, dim=3)
output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v])
output = output.view(sz_b, len_q, -1)
output = self.fc(output)
# epilogue
output = self.layer_norm(output + residual)
return output, attn


@@ -0,0 +1,72 @@
import numpy as np
import torch
import torch.nn as nn
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
output = torch.bmm(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# linear layers
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.fc = nn.Linear(n_head * d_v, d_model)
# initialize weights
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
nn.init.xavier_normal_(self.fc.weight)
# normalization
self.layer_norm = nn.LayerNorm(d_model)
# scaled dot-product
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
def forward(self, q, k, v, mask=None):
# dimensions
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
# linear transformations
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# scaled dot-product attention
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
        if mask is not None:
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
# linear transformation
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.fc(output)
# normalization
output = self.layer_norm(output + residual)
return output, attn


@@ -0,0 +1,56 @@
import triton
import numpy as np
from enum import Enum
class MODE(Enum):
TF = 1
TORCH = 2
mode = None
try:
import tensorflow as tf
mode = MODE.TF
except ModuleNotFoundError:
pass
try:
import torch
mode = MODE.TORCH
except ModuleNotFoundError:
pass
C, H, W, B = 32, 1, 1, 128
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
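# note: tensors here are laid out as (C, H, W, B) -- channels first, batch
# last -- which is the layout triton.ops.batchnorm reduces over in this test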
if mode == MODE.TORCH:
fw_x = torch.from_numpy(x).cuda()
fw_gamma = torch.from_numpy(gamma).cuda()
fw_beta = torch.from_numpy(beta).cuda()
fw_dy = torch.from_numpy(dy).cuda()
# register gradients
fw_x.requires_grad_(True)
fw_gamma.requires_grad_(True)
fw_beta.requires_grad_(True)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
fw_y.backward(fw_dy)
if mode == MODE.TF:
fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
# run
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
sess.run(tf.global_variables_initializer())
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
print(result)

python/examples/einsum.py

@@ -0,0 +1,197 @@
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
#torch.backends.cudnn.benchmark = True
configs = []
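# each config entry: (a_shape, b_shape, c_shape, reference torch fn or None,
#                     einsum expression, extra index arrays)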
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
(8192, 8192, 8192),
# (64, 64, 64000),
# (64, 64, 128000),
# (256, 256, 64000),
# (256, 256, 128000),
# (1536, 16, 1536),
# (1536, 32, 1536),
# (1536, 64, 1536),
# (1536, 128, 1536),
# (4096, 16, 4096),
# (4096, 32, 4096),
# (4096, 64, 4096),
# (4096, 128, 4096),
# (127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a.t(), b)
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b.t())
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
# Relative attention
NTHSE = [
(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
# (64, 1024, 1, 64, 64),
(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# 1D Dense convolution
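# triton.ops.einsum extends Einstein notation with index arithmetic: in
# 'nc(h+r),crk->nkh' the input is indexed at h+r and reduced over r, which
# expresses a dense convolution directly as an einsum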
NCHKR = [
(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
configs += [([N, C, H],
[C, R, K],
[N, K, H - R + 1],
torch_fn,
'nc(h+r),crk->nkh',
dict())]
# 2D Dense convolution
NCHWKRS = [
(8, 64, 128, 128, 768, 3, 3),
(8, 128, 64, 64, 256, 3, 3),
(8, 256, 32, 32, 512, 3, 3),
(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
configs += [([N, C, H, W],
[C, R, S, K],
             [N, K, H - R + 1, W - S + 1],
torch_fn,
'nc(h+r)(w+s),crsk->nkhw',
dict())]
# 3D Dense Convolution
NCDHWKTRS = [
(8, 32, 27, 100, 100, 64, 3, 3, 3),
(8, 64, 23, 48, 48, 256, 3, 3, 3),
(8, 256, 19, 22, 22, 640, 3, 3, 3),
(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
configs += [([N, C, D, H, W],
[C, T, R, S, K],
             [N, K, D - T + 1, H - R + 1, W - S + 1],
torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict())]
# Shift convolution
shift_cuda = load(
    'shift_cuda', ['kernels/shift_cuda.cpp',
                   'kernels/shift_cuda_kernel.cu'],
    extra_cflags=['-O3'])
class shift(torch.autograd.Function):
@staticmethod
def forward(ctx, x, shift):
ctx.save_for_backward(shift)
return shift_cuda.forward(x, shift)
@staticmethod
def backward(ctx, grad_output):
shift, = ctx.saved_tensors
grad_output = shift_cuda.backward(grad_output, shift)
return grad_output, None
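# shift.apply(x, s) moves each channel of an NCHW tensor by its per-channel
# (w, h) offset in s, using the CUDA extension built above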
NCHWKRS = [
#(8, 64, 128, 128, 128, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
def shift_conv(a, b, **kwargs):
shift_h, shift_w = kwargs['sh'], kwargs['sw']
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
shift_torch = torch.from_numpy(shift_torch).cuda()
a = shift.apply(a, shift_torch)
b = b.permute(1, 0)
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
return torch.nn.functional.conv2d(a, b)
configs += [([N, C, H, W],
[C, K],
[N, K, H, W],
shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w})]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.FloatTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# triton output
    tc = triton.ops.einsum(expr, a, b, c_shape, arrays=arrays, bench=True)
# reference output
if torch_fn:
rc = torch_fn(a, b, **arrays)
else:
rc = torch.einsum(expr, a, b)
# performance relative to equivalent matrix multiplication
ctx = triton.ctx_registry[tc]
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
cmp_eqbmm = True
if cmp_eqbmm:
a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda()
        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench=True)
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
cmp_str = f'({ratio:4.2f})'
else:
cmp_str = ''
# test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')


@@ -0,0 +1,42 @@
#include <torch/torch.h>
#include <vector>
// CUDA forward declarations
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift);
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift);
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor shift_forward(
const at::Tensor input,
const at::Tensor shift) {
CHECK_INPUT(input);
CHECK_INPUT(shift);
return shift_cuda_forward(input, shift);
}
at::Tensor shift_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
CHECK_INPUT(grad_input);
CHECK_INPUT(shift);
return shift_cuda_backward(grad_input, shift);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &shift_forward, "Shift forward (CUDA)");
m.def("backward", &shift_backward, "Shift backward (CUDA)");
}


@@ -0,0 +1,111 @@
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
const scalar_t* __restrict__ input,
const int32_t* __restrict__ shift,
scalar_t* __restrict__ output,
const int32_t B,
const int32_t C,
const int32_t H,
const int32_t W) {
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t size = B*C*H*W;
const int32_t CHW = C*H*W;
const int32_t HW = H*W;
const int32_t b = idx / CHW;
const int32_t c = (idx - b*CHW) / HW;
const int32_t h = (idx - b*CHW - c*HW) / W;
const int32_t w = idx - b*CHW - c*HW - h*W;
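  // scatter: element (b, c, h, w) moves to (h + shift[2c+1], w + shift[2c]);
  // out-of-bounds targets are dropped, so the zero-initialized output keeps 0 there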
const int32_t target_w = w + shift[2*c];
const int32_t target_h = h + shift[2*c + 1];
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
output[target_idx] = input[idx];
}
}
template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
    const scalar_t* __restrict__ grad_input,
    scalar_t* __restrict__ grad_output,
    const int32_t* __restrict__ shift,
    const int32_t B,
    const int32_t C,
    const int32_t H,
    const int32_t W) {
  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t size = B*C*H*W;
  const int32_t CHW = C*H*W;
  const int32_t HW = H*W;
  const int32_t b = idx / CHW;
  const int32_t c = (idx - b*CHW) / HW;
  const int32_t h = (idx - b*CHW - c*HW) / W;
  const int32_t w = idx - b*CHW - c*HW - h*W;
  // the gradient of a shift is the opposite shift: scatter each incoming
  // gradient back to the location it was shifted from in the forward pass
  const int32_t target_w = w - shift[2*c];
  const int32_t target_h = h - shift[2*c + 1];
  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
    grad_output[target_idx] = grad_input[idx];
  }
}
} // namespace
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift) {
const auto B = input.size(0);
const auto C = input.size(1);
const auto H = input.size(2);
const auto W = input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto output = at::zeros_like(input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
input.data<scalar_t>(),
shift.data<int32_t>(),
output.data<scalar_t>(),
B,
C,
H,
W);
}));
return output;
}
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
const auto B = grad_input.size(0);
const auto C = grad_input.size(1);
const auto H = grad_input.size(2);
const auto W = grad_input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto grad_output = at::zeros_like(grad_input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
grad_input.data<scalar_t>(),
grad_output.data<scalar_t>(),
shift.data<int32_t>(),
B,
C,
H,
W);
}));
return grad_output;
}

python/setup.py

@@ -0,0 +1,122 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils
import glob
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn
import distutils.sysconfig
def find_llvm():
versions = ['9.0', '9', '90', '8.0', '8', '80']
supported = ['llvm-config-{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return paths[0]
config = distutils.spawn.find_executable('llvm-config')
    instructions = 'Please install llvm-{8, 9}-dev'
if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions)
version = os.popen('{config} --version'.format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.1.0':
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions:
self.build_extension(ext)
def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DLLVM_CONFIG=' + find_llvm()]
# tensorflow compatibility
try:
import tensorflow as tf
tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
tf_include_dirs = tf.sysconfig.get_include()
tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
cmake_args += ['-DTF_INCLUDE_DIRS=' + tf_include_dirs,
'-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
'-DTF_LIBS=' + tf_libs,
'-DTF_ABI=' + str(tf_abi)]
except ModuleNotFoundError:
pass
        cfg = 'Debug' if self.debug else 'Release'
build_args = ['--config', cfg]
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
else:
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j4']
env = os.environ.copy()
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
self.distribution.get_version())
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
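# fail early with a clear error if no supported llvm-config is on the PATH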
find_llvm()
directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
data = []
for d in directories:
files = glob.glob(os.path.join(d, '*.h'), recursive=False)
data += [os.path.relpath(f, os.path.pardir) for f in files]
setup(
name='triton',
version='0.1',
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops'],
package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
)

python/src/bindings.cc

@@ -0,0 +1,642 @@
#include <pybind11/pybind11.h>
#include <pybind11/buffer_info.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include <regex>
#include <algorithm>
#include "triton/runtime/function.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/parser.h"
#include "triton/lang/cpp.h"
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/tools/bench.hpp"
using namespace triton;
namespace rt = triton::runtime;
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, double> fp64scalar_map;
std::map<size_t, int64_t> i64scalar_map;
/* Grid map */
void register_grid(size_t id,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
}
void delete_grid(size_t id) {
id_grid_map.erase(id);
}
/* Function map */
void register_fn(size_t id,
const std::string& src,
const rt::function::options_space_t& opt) {
id_fn_map[id].reset(new rt::function(src, opt));
}
void delete_fn(size_t id) {
id_fn_map.erase(id);
}
void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() {
id_grid_map.clear();
id_fn_map.clear();
i64scalar_map.clear();
}
size_t make_op_id() {
return id_fn_map.size();
}
/* TF scalar wrapper */
size_t make_scalar_id() {
size_t ret = i64scalar_map.size();
i64scalar_map[ret] = int64_t();
return ret;
}
bool has_scalar(size_t id) {
return i64scalar_map.find(id) != i64scalar_map.end();
}
int64_t retrieve_scalar(size_t id) {
return i64scalar_map.at(id);
}
/* TF source-code generation */
inline std::string to_tf_ty(ir::type *ty) {
if(ty->is_integer_ty(1))
return "bool";
if(ty->is_integer_ty(8))
return "int8";
if(ty->is_integer_ty(16))
return "int16";
if(ty->is_integer_ty(32))
return "int32";
if(ty->is_integer_ty(64))
return "int64";
if(ty->is_half_ty())
return "float16";
if(ty->is_float_ty())
return "float";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "Tensor";
throw std::runtime_error("unknown type");
}
inline std::string to_tf_scalar_ty(ir::type *ty) {
if(ty->is_pointer_ty())
return to_tf_ty(ty->get_pointer_element_ty());
else {
return to_tf_ty(ty);
}
}
inline std::string ref_to_tf_ty(ir::type *ty) {
std::string res = to_tf_ty(ty);
if(ty->is_pointer_ty())
res = "const " + res + "&";
return res;
}
std::string tf_normalize(const std::string& name) {
std::string ret = name;
auto tolower = [](char c) { return std::tolower(c);};
std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
return ret;
}
struct tf_alloc_t{
enum type_t{
OUTPUT,
TEMP
};
tf_alloc_t(const std::string& _name, type_t _type)
: name(_name), type(_type), tf_name(tf_normalize(_name)){ }
std::string tf_name;
std::string name;
type_t type;
size_t shape_id;
};
typedef std::vector<tf_alloc_t> alloc_map_t;
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
for(unsigned i = 0; i < args.size(); i++){
ir::value *arg = args[i];
const std::string& name = arg->get_name();
std::string ty = to_tf_ty(arg->get_type());
if(!arg->get_type()->is_pointer_ty())
os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
else if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return x.name == name;
}) == allocs.end())
os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
else
os << " Tensor* " << name << " = nullptr;\n ";
}
}
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
// initialize shapes
for(const auto& x: allocs)
os << " TensorShape " << x.name << "_shape;\n ";
for(const auto& x: allocs)
os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
for(const auto& x: allocs)
os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
for(const auto& x: allocs)
os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
for(const auto& x: allocs)
os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
<< x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";
// allocate
int output = 0;
for(const auto& x: allocs){
if(x.type == tf_alloc_t::OUTPUT)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
else
os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
}
}
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
for(unsigned i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
if(!arg->get_type()->is_pointer_ty())
continue;
const std::string& name = arg->get_name();
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n ";
}
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
if(arg->get_type()->is_pointer_ty())
name = "&cu_" + name;
if(i > 0)
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id_), stream);\n ";
os << " };\n ";
os << " run();\n ";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args,
const alloc_map_t& allocs){
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = tf_normalize(arg->get_name());
if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")";
}
for(const auto& x: allocs)
os << ".HostMemory(\"" << x.tf_name << "_shape\")";
os << ", " + opname << ");\n";
}
void gen_tf_register_op(std::ostream &os, const std::string &name,
const std::vector<ir::argument*>& args,
const alloc_map_t& allocs){
os << "REGISTER_OP(\"" << name << "\")\n";
for(size_t i = 0; i < args.size(); i++)
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = tf_normalize(arg->get_name());
if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return name == x.tf_name;
}) == allocs.end())
os << " .Input(\"" << name << ": T" << i << "\")\n";
else
os << " .Input(\"" << name << "_shape: int32\")\n";
}
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT)
os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
size_t current = 0;
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT){
os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
}
os << " return Status::OK();\n";
os << " })\n";
os << ";\n";
}
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&copy, true);
for(auto it: opt.defines){
cpp.AddMacro(it.first, &it.second[0]);
}
cpp.Process(tokens);
// parse
Parser parser(tokens);
parser.Parse();
Generator gen(&parser);
gen.Gen(ir);
}
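// Generates the C++ source for a TensorFlow custom op wrapping the given
// Triton kernel: op inputs/outputs are derived from the kernel signature,
// with the names listed in `outputs`/`tmp` allocated by the op itself.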
std::tuple<std::string,
std::string> make_tensorflow_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt)
{
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
const std::vector<ir::argument*>& args = fn->args();
std::string name = fn->get_name();
std::string cc_name = name;
cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
std::string opname = cc_name + "Op";
// allocation info
alloc_map_t allocs;
for(size_t i = 0; i < outputs.size(); i++)
allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
for(size_t i = 0; i < tmp.size(); i++)
allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));
for(auto &x: allocs){
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == x.name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
x.shape_id = idx;
}
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
class )" << opname << R"(: public OpKernel {
public:
explicit )" << opname << R"((OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
)";
for(const auto& alloc: allocs)
oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";
oss << R"(
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
drv::cu_stream sstream(device.stream(), false);
drv::context* ctx = sstream.context();
drv::stream* stream = &sstream;
// extract inputs
)";
gen_extract_inputs(oss, args, allocs);
oss << R"(
// set outputs
)";
gen_set_outputs(oss, args, allocs);
oss << R"(
// wrap tensors
)";
gen_make_handles(oss, args);
oss << R"(
)";
oss << R"(
// launch function
)";
gen_make_launch_function(oss, args);
oss << R"(
}
private:
int id_;
int bench_;
int64 bench_id_;
)";
for(const auto& alloc: allocs)
oss << "DataType " << alloc.name << "_type;\n ";
oss << R"(
};
// register kernel builder
)";
gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
oss << R"(
// register op
)";
gen_tf_register_op(oss, cc_name, args, allocs);
return {oss.str(), name};
}
inline std::string to_torch_ty(ir::type *ty) {
if(ty->is_integer_ty())
return "int64_t";
if(ty->is_half_ty())
return "double";
if(ty->is_float_ty())
return "double";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "torch::Tensor";
throw std::runtime_error("unknown type");
}
inline std::string to_c_ty(ir::type *ty) {
if(ty->is_integer_ty(1))
return "bool";
if(ty->is_integer_ty(8))
return "int8_t";
if(ty->is_integer_ty(16))
return "int16_t";
if(ty->is_integer_ty(32))
return "int32_t";
if(ty->is_integer_ty(64))
return "int64_t";
if(ty->is_half_ty())
return "half";
if(ty->is_float_ty())
return "float";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "drv::cu_buffer";
throw std::runtime_error("unknown type");
}
void gen_torch_signature(std::ostringstream& oss,
ir::function* fn,
const std::vector<std::string>& outputs,
const std::string& name) {
const auto& args = fn->args();
std::vector<ir::type*> out_types;
for(const std::string& out: outputs) {
auto it = std::find_if(args.begin(), args.end(),
[&](ir::argument* arg) { return arg->get_name() == out; });
if(it == args.end())
throw std::runtime_error("unknown argument");
out_types.push_back((*it)->get_type());
}
std::string ret_ty;
if(out_types.empty())
ret_ty = "void";
else{
ir::type* ty = out_types[0];
ret_ty = to_torch_ty(ty);
if(out_types.size() > 1){
for(size_t i = 1; i < out_types.size(); i++)
if(out_types[i] != ty)
throw std::runtime_error("outputs of different types not supported by pytorch");
ret_ty = "std::vector<" + ret_ty + ">";
}
}
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
oss << ", ";
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
}
oss << ")";
}
void gen_torch_init_driver(std::ostringstream &oss,
const std::vector<ir::argument*>&args) {
ir::argument* tensor = nullptr;
for(ir::argument* arg: args)
if(arg->get_type()->is_pointer_ty()){
tensor = arg;
break;
}
oss << " // Wrap CUDA handles" << std::endl;
oss << " c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl;
oss << " // Get stream" << std::endl;
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
oss << " triton::driver::context* ctx = stream.context();" << std::endl;
}
void gen_torch_make_handles(std::ostream &os,
const std::vector<ir::argument*>& args) {
for(unsigned i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
const std::string& name = arg->get_name();
ir::type* ty = arg->get_type();
if(!ty->is_pointer_ty())
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
else{
os << " CHECK_INPUT(" << name << ");" << std::endl;
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
" (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
}
}
}
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = "arg_" + arg->get_name();
if(arg->get_type()->is_pointer_ty())
name = "&" + name;
if(i > 0)
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id), &stream);\n";
os << " };\n";
os << " run();\n";
os << " if(bench > 0)\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
if(outputs.size() == 1){
os << " return " << outputs[0] << ";" << std::endl;
return;
}
os << " return {";
for(size_t i = 0; i < outputs.size(); i++){
if(i > 0)
os << ", ";
os << outputs[i];
}
os << "};" << std::endl;
}
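// Counterpart of make_tensorflow_src: generates the C++ source for a
// TorchScript custom op wrapping the given Triton kernel.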
std::tuple<std::string,
std::string> make_torch_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
std::string name = fn->get_name();
// generate framework code
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
)";
gen_torch_signature(oss, fn, outputs, name);
oss << " {" << std::endl;
gen_torch_init_driver(oss, fn->args());
gen_torch_make_handles(oss, fn->args());
gen_torch_make_launch_function(oss, fn->args());
gen_torch_ret(oss, outputs);
oss << "}" << std::endl;
oss << std::endl;
oss << std::endl;
oss << "static auto registry = torch::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
return {oss.str(), name};
}
typedef triton::runtime::function::options_t options_t;
typedef triton::runtime::function::options_space_t options_space_t;
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
// framework binding source code generation
m.def("make_tensorflow_src", &make_tensorflow_src,
"Creates C++ source code for a custom Tensorflow op "
"corresponding to the specified Triton kernel");
m.def("make_torch_src", &make_torch_src,
"Creates C++ source code for a custom PyTorch op ");
// bindings for triton classes
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
.def("d", &options_t::D<int>)
.def_readonly("num_warps", &options_t::num_warps);
pybind11::class_<options_space_t>(m, "options_space")
.def(pybind11::init<>())
.def_readwrite("defines", &options_space_t::defines)
.def_readwrite("num_warps", &options_space_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);
m.def("register_cst", &register_cst);
m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id);
m.def("make_scalar_id", &make_scalar_id);
m.def("retrieve_scalar", &retrieve_scalar);
m.def("cleanup", &cleanup);
}

python/src/pybind11/attr.h

@@ -0,0 +1,493 @@
/*
pybind11/attr.h: Infrastructure for processing custom
type and function attributes
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "cast.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// \addtogroup annotations
/// @{
/// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
/// Annotation for operators
struct is_operator { };
/// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } };
/// Annotation for documentation
struct doc { const char *value; doc(const char *value) : value(value) { } };
/// Annotation for function names
struct name { const char *value; name(const char *value) : value(value) { } };
/// Annotation indicating that a function is an overload associated with a given "sibling"
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
/// Annotation indicating that a class derives from another given type
template <typename T> struct base {
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
base() { }
};
/// Keep patient alive while nurse lives
template <size_t Nurse, size_t Patient> struct keep_alive { };
/// Annotation indicating that a class is involved in a multiple inheritance relationship
struct multiple_inheritance { };
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
struct dynamic_attr { };
/// Annotation which enables the buffer protocol for a type
struct buffer_protocol { };
/// Annotation which requests that a special metaclass is created for a type
struct metaclass {
handle value;
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
metaclass() {}
/// Override pybind11's default metaclass
explicit metaclass(handle value) : value(value) { }
};
/// Annotation that marks a class as local to the module:
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
/// Annotation to mark enums as an arithmetic type
struct arithmetic { };
/** \rst
A call policy which places one or more guard variables (``Ts...``) around the function call.
For example, this definition:
.. code-block:: cpp
m.def("foo", foo, py::call_guard<T>());
is equivalent to the following pseudocode:
.. code-block:: cpp
m.def("foo", [](args...) {
T scope_guard;
return foo(args...); // forwarded arguments
});
\endrst */
template <typename... Ts> struct call_guard;
template <> struct call_guard<> { using type = detail::void_type; };
template <typename T>
struct call_guard<T> {
static_assert(std::is_default_constructible<T>::value,
"The guard type must be default constructible");
using type = T;
};
template <typename T, typename... Ts>
struct call_guard<T, Ts...> {
struct type {
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
typename call_guard<Ts...>::type next{};
};
};
/// @} annotations
NAMESPACE_BEGIN(detail)
/* Forward declarations */
enum op_id : int;
enum op_type : int;
struct undefined_t;
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
/// Internal data structure which holds metadata about a keyword argument
struct argument_record {
const char *name; ///< Argument name
const char *descr; ///< Human-readable version of the argument value
handle value; ///< Associated Python object
bool convert : 1; ///< True if the argument is allowed to convert when loading
bool none : 1; ///< True if None is allowed when loading
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
: name(name), descr(descr), value(value), convert(convert), none(none) { }
};
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
struct function_record {
function_record()
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
// User-specified documentation string
char *doc = nullptr;
/// Human-readable version of the function signature
char *signature = nullptr;
/// List of registered keyword arguments
std::vector<argument_record> args;
/// Pointer to lambda function which converts arguments and performs the actual call
handle (*impl) (function_call &) = nullptr;
/// Storage for the wrapped function pointer and captured data, if any
void *data[3] = { };
/// Pointer to custom destructor for 'data' (if needed)
void (*free_data) (function_record *ptr) = nullptr;
/// Return value policy associated with this function
return_value_policy policy = return_value_policy::automatic;
/// True if name == '__init__'
bool is_constructor : 1;
/// True if this is a new-style `__init__` defined in `detail/init.h`
bool is_new_style_constructor : 1;
/// True if this is a stateless function pointer
bool is_stateless : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if the function has a '*args' argument
bool has_args : 1;
/// True if the function has a '**kwargs' argument
bool has_kwargs : 1;
/// True if this is a method
bool is_method : 1;
/// Number of arguments (including py::args and/or py::kwargs, if present)
std::uint16_t nargs;
/// Python method object
PyMethodDef *def = nullptr;
/// Python handle to the parent scope (a class or a module)
handle scope;
/// Python handle to the sibling function representing an overload chain
handle sibling;
/// Pointer to next overload
function_record *next = nullptr;
};
/// Special data structure which (temporarily) holds metadata about a bound class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
default_holder(true), module_local(false) { }
/// Handle to the parent scope
handle scope;
/// Name of the class
const char *name = nullptr;
// Pointer to RTTI type_info data structure
const std::type_info *type = nullptr;
/// How large is the underlying C++ type?
size_t type_size = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
/// Function pointer to class_<..>::dealloc
void (*dealloc)(detail::value_and_holder &) = nullptr;
/// List of base classes of the newly created type
list bases;
/// Optional docstring
const char *doc = nullptr;
/// Custom metaclass (optional)
handle metaclass;
/// Multiple inheritance marker
bool multiple_inheritance : 1;
/// Does the class manage a __dict__?
bool dynamic_attr : 1;
/// Does the class implement the buffer protocol?
bool buffer_protocol : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
/// Is the class definition local to the module shared object?
bool module_local : 1;
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) +
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname + "\" " +
(base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)
dynamic_attr = true;
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
}
};
inline function_call::function_call(const function_record &f, handle p) :
func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);
}
/// Tag for a new-style `__init__` defined in `detail/init.h`
struct is_new_style_constructor { };
/**
* Partial template specializations to process custom attributes provided to
* cpp_function_ and class_. These are either used to initialize the respective
* fields in the type_record and function_record data structures or executed at
* runtime to deal with custom call policies (e.g. keep_alive).
*/
template <typename T, typename SFINAE = void> struct process_attribute;
template <typename T> struct process_attribute_default {
/// Default implementation: do nothing
static void init(const T &, function_record *) { }
static void init(const T &, type_record *) { }
static void precall(function_call &) { }
static void postcall(function_call &, handle) { }
};
/// Process an attribute specifying the function's name
template <> struct process_attribute<name> : process_attribute_default<name> {
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring
template <> struct process_attribute<doc> : process_attribute_default<doc> {
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring (provided as a C-style string)
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
};
template <> struct process_attribute<char *> : process_attribute<const char *> { };
/// Process an attribute indicating the function's return value policy
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
};
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
};
/// Process an attribute which indicates that this function is a method
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
};
/// Process an attribute which indicates the parent scope of a method
template <> struct process_attribute<scope> : process_attribute_default<scope> {
static void init(const scope &s, function_record *r) { r->scope = s.value; }
};
/// Process an attribute which indicates that this function is an operator
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
};
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
};
/// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> {
static void init(const arg &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a keyword argument attribute (*with* a default value)
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
static void init(const arg_v &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
if (!a.value) {
#if !defined(NDEBUG)
std::string descr("'");
if (a.name) descr += std::string(a.name) + ": ";
descr += a.type + "'";
if (r->is_method) {
if (r->name)
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
else
descr += " in method of '" + (std::string) str(r->scope) + "'";
} else if (r->name) {
descr += " in function '" + (std::string) r->name + "'";
}
pybind11_fail("arg(): could not convert default argument "
+ descr + " into a Python object (type not registered yet?)");
#else
pybind11_fail("arg(): could not convert default argument "
"into a Python object (type not registered yet?). "
"Compile in debug mode for more information.");
#endif
}
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
template <typename T>
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
};
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
template <typename T>
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
};
/// Process a multiple inheritance attribute
template <>
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
};
template <>
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
};
template <>
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
};
template <>
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
};
template <>
struct process_attribute<module_local> : process_attribute_default<module_local> {
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
};
/// Process an 'arithmetic' attribute for enums (does nothing here)
template <>
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
template <typename... Ts>
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
/**
* Process a keep_alive call policy -- invokes keep_alive_impl during the
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
* otherwise
*/
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void postcall(function_call &, handle) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void precall(function_call &) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
};
/// Recursively iterate over variadic template arguments
template <typename... Args> struct process_attributes {
static void init(const Args&... args, function_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void init(const Args&... args, type_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void precall(function_call &call) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
ignore_unused(unused);
}
static void postcall(function_call &call, handle fn_ret) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
ignore_unused(unused);
}
};
template <typename T>
using is_call_guard = is_instantiation<call_guard, T>;
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
template <typename... Extra>
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
/// Check the number of named arguments at compile time
template <typename... Extra,
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,108 @@
/*
pybind11/buffer_info.h: Python buffer object interface
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// Information record describing a Python buffer object
struct buffer_info {
void *ptr = nullptr; // Pointer to the underlying storage
ssize_t itemsize = 0; // Size of individual items in bytes
ssize_t size = 0; // Total number of entries
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
ssize_t ndim = 0; // Number of dimensions
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
std::vector<ssize_t> strides; // Number of bytes between adjacent entries (one stride per dimension)
buffer_info() { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < (size_t) ndim; ++i)
size *= shape[i];
}
template <typename T>
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
template <typename T>
buffer_info(T *ptr, ssize_t size)
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
this->view = view;
this->ownview = ownview;
}
buffer_info(const buffer_info &) = delete;
buffer_info& operator=(const buffer_info &) = delete;
buffer_info(buffer_info &&other) {
(*this) = std::move(other);
}
buffer_info& operator=(buffer_info &&rhs) {
ptr = rhs.ptr;
itemsize = rhs.itemsize;
size = rhs.size;
format = std::move(rhs.format);
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(ownview, rhs.ownview);
return *this;
}
~buffer_info() {
if (view && ownview) { PyBuffer_Release(view); delete view; }
}
private:
struct private_ctr_tag { };
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
Py_buffer *view = nullptr;
bool ownview = false;
};
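/** \rst
    A minimal sketch of constructing a ``buffer_info`` for a dense, row-major
    2-D float array (``rows``, ``cols`` and ``data`` are assumed names):
    .. code-block:: cpp
        const ssize_t rows = 3, cols = 4;
        std::vector<float> data(rows * cols);
        buffer_info info(
            data.data(), sizeof(float), format_descriptor<float>::format(),
            2,                                          // ndim
            { rows, cols },                             // shape
            { (ssize_t) (sizeof(float) * cols),         // strides (row-major)
              (ssize_t) sizeof(float) });
    \endrst */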
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

2128
python/src/pybind11/cast.h Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,162 @@
/*
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <cmath>
#include <ctime>
#include <chrono>
#include <datetime.h>
// Backport the PyDateTime_DELTA functions from Python3.3 if required
#ifndef PyDateTime_DELTA_GET_DAYS
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
#endif
#ifndef PyDateTime_DELTA_GET_SECONDS
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
#endif
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename type> class duration_caster {
public:
typedef typename type::rep rep;
typedef typename type::period period;
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
// If invoked with datetime.delta object
if (PyDelta_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
return true;
}
// If invoked with a float we assume it is seconds and convert
else if (PyFloat_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
return true;
}
else return false;
}
// If this is a duration just return it back
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
return src;
}
// If this is a time_point get the time_since_epoch
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
return src.time_since_epoch();
}
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Use overloaded function to get our duration from our source
// Works out if it is a duration or time_point and get the duration
auto d = get_duration(src);
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using dd_t = duration<int, std::ratio<86400>>;
using ss_t = duration<int, std::ratio<1>>;
using us_t = duration<int, std::micro>;
auto dd = duration_cast<dd_t>(d);
auto subd = d - dd;
auto ss = duration_cast<ss_t>(subd);
auto us = duration_cast<us_t>(subd - ss);
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
};
// This is for casting times on the system clock into datetime.datetime instances
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
public:
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
if (PyDateTime_Check(src.ptr())) {
std::tm cal;
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
cal.tm_isdst = -1;
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
return true;
}
else return false;
}
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
std::time_t tt = system_clock::to_time_t(src);
// this function uses static memory so it's best to copy it out asap just in case
// otherwise other code that is using localtime may break this (not just python code)
std::tm localtime = *std::localtime(&tt);
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using us_t = duration<int, std::micro>;
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
localtime.tm_mon + 1,
localtime.tm_mday,
localtime.tm_hour,
localtime.tm_min,
localtime.tm_sec,
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
};
// Other clocks that are not the system clock are not measured as datetime.datetime objects
// since they are not measured on calendar time. So instead we just make them timedeltas
// Or if they have passed us a time as a float we convert that
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
};
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
: public duration_caster<std::chrono::duration<Rep, Period>> {
};
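/** \rst
    A minimal sketch of how these casters are typically exercised (the
    ``wait`` binding is an assumed example):
    .. code-block:: cpp
        m.def("wait", [](std::chrono::milliseconds d) { return d.count(); });
    From Python, both ``datetime.timedelta(milliseconds=250)`` and the float
    ``0.25`` (interpreted as seconds) load as the same 250 ms duration.
    \endrst */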
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -0,0 +1,2 @@
#include "detail/common.h"
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."

View File

@@ -0,0 +1,65 @@
/*
pybind11/complex.h: Complex number support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <complex>
/// glibc defines I as a macro which breaks things, e.g., boost template names
#ifdef I
# undef I
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = { 'Z', c, '\0' };
static std::string format() { return std::string(value); }
};
#ifndef PYBIND11_CPP17
template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
NAMESPACE_BEGIN(detail)
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
template <typename T> class type_caster<std::complex<T>> {
public:
bool load(handle src, bool convert) {
if (!src)
return false;
if (!convert && !PyComplex_Check(src.ptr()))
return false;
Py_complex result = PyComplex_AsCComplex(src.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
value = std::complex<T>((T) result.real, (T) result.imag);
return true;
}
static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
}
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -0,0 +1,623 @@
/*
pybind11/detail/class.h: Python C API implementation details for py::class_
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../attr.h"
#include "../options.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if PY_VERSION_HEX >= 0x03030000
# define PYBIND11_BUILTIN_QUALNAME
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
#else
// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type
// signatures; in 3.3+ this macro expands to nothing:
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj)
#endif
inline PyTypeObject *type_incref(PyTypeObject *type) {
Py_INCREF(type);
return type;
}
#if !defined(PYPY_VERSION)
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
return PyProperty_Type.tp_descr_get(self, cls, cls);
}
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
return PyProperty_Type.tp_descr_set(self, cls, value);
}
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
methods are modified to always use the object type instead of a concrete instance.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
constexpr auto *name = "pybind11_static_property";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_static_property_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyProperty_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_descr_get = pybind11_static_get;
type->tp_descr_set = pybind11_static_set;
if (PyType_Ready(type) < 0)
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
#else // PYPY
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
This function will only be called once so performance isn't really a concern.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
auto d = dict();
PyObject *result = PyRun_String(R"(\
class pybind11_static_property(property):
def __get__(self, obj, cls):
return property.__get__(self, cls, cls)
def __set__(self, obj, value):
cls = obj if isinstance(obj, type) else type(obj)
property.__set__(self, cls, value)
)", Py_file_input, d.ptr(), d.ptr()
);
if (result == nullptr)
throw error_already_set();
Py_DECREF(result);
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
}
#endif // PYPY
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
By default, Python replaces the `static_property` itself, but for wrapped C++ types
we need to call `static_property.__set__()` in order to propagate the new value to
the underlying C++ data structure. */
extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
// Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
// descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
// The following assignment combinations are possible:
// 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
// 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
// 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
const auto static_prop = (PyObject *) get_internals().static_property_type;
const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
&& !PyObject_IsInstance(value, static_prop);
if (call_descr_set) {
// Call `static_property.__set__()` instead of replacing the `static_property`.
#if !defined(PYPY_VERSION)
return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
#else
if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
Py_DECREF(result);
return 0;
} else {
return -1;
}
#endif
} else {
// Replace existing attribute.
return PyType_Type.tp_setattro(obj, name, value);
}
}
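/** \rst
    A sketch of the behaviour this hook enables (``Example`` and ``value``
    are assumed names):
    .. code-block:: cpp
        py::class_<Example>(m, "Example")
            .def_readwrite_static("value", &Example::value);
        // In Python, `Example.value = 42` now routes through
        // `static_property.__set__()` and writes to the C++ static member
        // (case 1 above) instead of replacing the descriptor on the class.
    \endrst */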
#if PY_MAJOR_VERSION >= 3
/**
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
* to do a special case bypass for PyInstanceMethod_Types.
*/
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
if (descr && PyInstanceMethod_Check(descr)) {
Py_INCREF(descr);
return descr;
}
else {
return PyType_Type.tp_getattro(obj, name);
}
}
#endif
/** This metaclass is assigned by default to all pybind11 types and is required in order
for static properties to function correctly. Users may override this using `py::metaclass`.
Return value: New reference. */
inline PyTypeObject* make_default_metaclass() {
constexpr auto *name = "pybind11_type";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyType_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_setattro = pybind11_meta_setattro;
#if PY_MAJOR_VERSION >= 3
type->tp_getattro = pybind11_meta_getattro;
#endif
if (PyType_Ready(type) < 0)
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
/// base classes with pointers that are different from the instance value pointer so that we can
/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
for (auto &c : parent_tinfo->implicit_casts) {
if (c.first == tinfo->cpptype) {
auto *parentptr = c.second(valueptr);
if (parentptr != valueptr)
f(parentptr, self);
traverse_offset_bases(parentptr, parent_tinfo, self, f);
break;
}
}
}
}
}
inline bool register_instance_impl(void *ptr, instance *self) {
get_internals().registered_instances.emplace(ptr, self);
return true; // unused, but gives the same signature as the deregister func
}
inline bool deregister_instance_impl(void *ptr, instance *self) {
auto &registered_instances = get_internals().registered_instances;
auto range = registered_instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
if (Py_TYPE(self) == Py_TYPE(it->second)) {
registered_instances.erase(it);
return true;
}
}
return false;
}
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
register_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
}
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
bool ret = deregister_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
return ret;
}
/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
/// to a reference or pointer), and initialization is done by an `__init__` function.
inline PyObject *make_new_instance(PyTypeObject *type) {
#if defined(PYPY_VERSION)
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
// object is a plain Python type (i.e. not derived from an extension type). Fix it.
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
if (type->tp_basicsize < instance_size) {
type->tp_basicsize = instance_size;
}
#endif
PyObject *self = type->tp_alloc(type, 0);
auto inst = reinterpret_cast<instance *>(self);
// Allocate the value/holder internals:
inst->allocate_layout();
inst->owned = true;
return self;
}
/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
return make_new_instance(type);
}
/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
/// following default function will be used which simply throws an exception.
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
PyTypeObject *type = Py_TYPE(self);
std::string msg;
#if defined(PYPY_VERSION)
msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
#endif
msg += type->tp_name;
msg += ": No constructor defined!";
PyErr_SetString(PyExc_TypeError, msg.c_str());
return -1;
}
inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto instance = reinterpret_cast<detail::instance *>(nurse);
instance->has_patients = true;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);
}
inline void clear_patients(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
auto &internals = get_internals();
auto pos = internals.patients.find(self);
assert(pos != internals.patients.end());
// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// from the unordered_map first.
auto patients = std::move(pos->second);
internals.patients.erase(pos);
instance->has_patients = false;
for (PyObject *&patient : patients)
Py_CLEAR(patient);
}
/// Clears all internal data from the instance and removes it from registered instances in
/// preparation for deallocation.
inline void clear_instance(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
// Deallocate any values/holders, if present:
for (auto &v_h : values_and_holders(instance)) {
if (v_h) {
// We have to deregister before we call dealloc because, for virtual MI types, we still
// need to be able to get the parent pointers.
if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
if (instance->owned || v_h.holder_constructed())
v_h.type->dealloc(v_h);
}
}
// Deallocate the value/holder layout internals:
instance->deallocate_layout();
if (instance->weakrefs)
PyObject_ClearWeakRefs(self);
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
if (dict_ptr)
Py_CLEAR(*dict_ptr);
if (instance->has_patients)
clear_patients(self);
}
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
clear_instance(self);
auto type = Py_TYPE(self);
type->tp_free(self);
// `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
// as part of a derived type's dealloc, in which case we're not allowed to decref
// the type here. For cross-module compatibility, we shouldn't compare directly
// with `pybind11_object_dealloc`, but with the common one stashed in internals.
auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
Py_DECREF(type);
}
/** Create the type which can be used as a common base for all classes. This is
needed in order to satisfy Python's requirements for multiple inheritance.
Return value: New reference. */
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
constexpr auto *name = "pybind11_object";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail("make_object_base_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyBaseObject_Type);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_new = pybind11_object_new;
type->tp_init = pybind11_object_init;
type->tp_dealloc = pybind11_object_dealloc;
/* Support weak references (needed for the keep_alive feature) */
type->tp_weaklistoffset = offsetof(instance, weakrefs);
if (PyType_Ready(type) < 0)
pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
return (PyObject *) heap_type;
}
/// dynamic_attr: Support for `d = instance.__dict__`.
extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
if (!dict)
dict = PyDict_New();
Py_XINCREF(dict);
return dict;
}
/// dynamic_attr: Support for `instance.__dict__ = dict()`.
extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
if (!PyDict_Check(new_dict)) {
PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
Py_TYPE(new_dict)->tp_name);
return -1;
}
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_INCREF(new_dict);
Py_CLEAR(dict);
dict = new_dict;
return 0;
}
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_VISIT(dict);
return 0;
}
/// dynamic_attr: Allow the GC to clear the dictionary.
extern "C" inline int pybind11_clear(PyObject *self) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_CLEAR(dict);
return 0;
}
/// Give instances of this type a `__dict__` and opt into garbage collection.
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
auto type = &heap_type->ht_type;
#if defined(PYPY_VERSION)
pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
"currently not supported in "
"conjunction with PyPy!");
#endif
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
type->tp_dictoffset = type->tp_basicsize; // place dict at the end
type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it
type->tp_traverse = pybind11_traverse;
type->tp_clear = pybind11_clear;
static PyGetSetDef getset[] = {
{const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}
};
type->tp_getset = getset;
}
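/** \rst
    A sketch of opting a bound type into dynamic attributes (``Pet`` is an
    assumed class):
    .. code-block:: cpp
        py::class_<Pet>(m, "Pet", py::dynamic_attr())
            .def(py::init<>());
        // Instances now expose a mutable `__dict__`, so Python code such as
        // `p = Pet(); p.nickname = "Rex"` works, at the cost of opting the
        // type into garbage-collector traversal (Py_TPFLAGS_HAVE_GC).
    \endrst */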
/// buffer_protocol: Fill in the view as specified by flags.
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
type_info *tinfo = nullptr;
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
tinfo = get_type_info((PyTypeObject *) type.ptr());
if (tinfo && tinfo->get_buffer)
break;
}
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
if (view)
view->obj = nullptr;
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
return -1;
}
std::memset(view, 0, sizeof(Py_buffer));
buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
view->obj = obj;
view->ndim = 1;
view->internal = info;
view->buf = info->ptr;
view->itemsize = info->itemsize;
view->len = view->itemsize;
for (auto s : info->shape)
view->len *= s;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
view->format = const_cast<char *>(info->format.c_str());
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->ndim = (int) info->ndim;
view->strides = &info->strides[0];
view->shape = &info->shape[0];
}
Py_INCREF(view->obj);
return 0;
}
/// buffer_protocol: Release the resources of the buffer.
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
delete (buffer_info *) view->internal;
}
/// Give this type a buffer interface.
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
#if PY_MAJOR_VERSION < 3
heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
#endif
heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
}
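/** \rst
    A sketch of the user-facing side of this machinery (``Matrix`` and its
    accessors are assumed):
    .. code-block:: cpp
        py::class_<Matrix>(m, "Matrix", py::buffer_protocol())
            .def_buffer([](Matrix &mat) -> py::buffer_info {
                return py::buffer_info(
                    mat.data(), sizeof(float),
                    py::format_descriptor<float>::format(),
                    2, { mat.rows(), mat.cols() },
                    { sizeof(float) * mat.cols(), sizeof(float) });
            });
    \endrst */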
/** Create a brand new Python type according to the `type_record` specification.
Return value: New reference. */
inline PyObject* make_new_python_type(const type_record &rec) {
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
auto qualname = name;
if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
#if PY_MAJOR_VERSION >= 3
qualname = reinterpret_steal<object>(
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
#else
qualname = str(rec.scope.attr("__qualname__").cast<std::string>() + "." + rec.name);
#endif
}
object module;
if (rec.scope) {
if (hasattr(rec.scope, "__module__"))
module = rec.scope.attr("__module__");
else if (hasattr(rec.scope, "__name__"))
module = rec.scope.attr("__name__");
}
auto full_name = c_str(
#if !defined(PYPY_VERSION)
module ? str(module).cast<std::string>() + "." + rec.name :
#endif
rec.name);
char *tp_doc = nullptr;
if (rec.doc && options::show_user_defined_docstrings()) {
/* Allocate memory for docstring (using PyObject_MALLOC, since
Python will free this later on) */
size_t size = strlen(rec.doc) + 1;
tp_doc = (char *) PyObject_MALLOC(size);
memcpy((void *) tp_doc, rec.doc, size);
}
auto &internals = get_internals();
auto bases = tuple(rec.bases);
auto base = (bases.size() == 0) ? internals.instance_base
: bases[0].ptr();
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
: internals.default_metaclass;
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
heap_type->ht_name = name.release().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = qualname.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = full_name;
type->tp_doc = tp_doc;
type->tp_base = type_incref((PyTypeObject *)base);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
if (bases.size() > 0)
type->tp_bases = bases.release().ptr();
/* Don't inherit base __init__ */
type->tp_init = pybind11_object_init;
/* Supported protocols */
type->tp_as_number = &heap_type->as_number;
type->tp_as_sequence = &heap_type->as_sequence;
type->tp_as_mapping = &heap_type->as_mapping;
/* Flags */
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
#if PY_MAJOR_VERSION < 3
type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
#endif
if (rec.dynamic_attr)
enable_dynamic_attributes(heap_type);
if (rec.buffer_protocol)
enable_buffer_protocol(heap_type);
if (PyType_Ready(type) < 0)
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
/* Register type with the parent scope */
if (rec.scope)
setattr(rec.scope, rec.name, (PyObject *) type);
else
Py_INCREF(type); // Keep it alive forever (reference leak)
if (module) // Needed by pydoc
setattr((PyObject *) type, "__module__", module);
PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
return (PyObject *) type;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -0,0 +1,807 @@
/*
pybind11/detail/common.h -- Basic macros
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#if !defined(NAMESPACE_BEGIN)
# define NAMESPACE_BEGIN(name) namespace name {
#endif
#if !defined(NAMESPACE_END)
# define NAMESPACE_END(name) }
#endif
// Robust support for some features and loading modules compiled against different pybind versions
// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on
// the main `pybind11` namespace.
#if !defined(PYBIND11_NAMESPACE)
# ifdef __GNUG__
# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
# else
# define PYBIND11_NAMESPACE pybind11
# endif
#endif
#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER)
# if __cplusplus >= 201402L
# define PYBIND11_CPP14
# if __cplusplus >= 201703L
# define PYBIND11_CPP17
# endif
# endif
#elif defined(_MSC_VER) && __cplusplus == 199711L
// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented)
// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer
# if _MSVC_LANG >= 201402L
# define PYBIND11_CPP14
# if _MSVC_LANG > 201402L && _MSC_VER >= 1910
# define PYBIND11_CPP17
# endif
# endif
#endif
// Compiler version assertions
#if defined(__INTEL_COMPILER)
# if __INTEL_COMPILER < 1700
# error pybind11 requires Intel C++ compiler v17 or newer
# endif
#elif defined(__clang__) && !defined(__apple_build_version__)
# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
# error pybind11 requires clang 3.3 or newer
# endif
#elif defined(__clang__)
// Apple changes clang version macros to its Xcode version; the first Xcode release based on
// (upstream) clang 3.3 was Xcode 5:
# if __clang_major__ < 5
# error pybind11 requires Xcode/clang 5.0 or newer
# endif
#elif defined(__GNUG__)
# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
# error pybind11 requires gcc 4.8 or newer
# endif
#elif defined(_MSC_VER)
// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features
// (e.g. std::negation) added in 2015u3:
# if _MSC_FULL_VER < 190024210
# error pybind11 requires MSVC 2015 update 3 or newer
# endif
#endif
#if !defined(PYBIND11_EXPORT)
# if defined(WIN32) || defined(_WIN32)
# define PYBIND11_EXPORT __declspec(dllexport)
# else
# define PYBIND11_EXPORT __attribute__ ((visibility("default")))
# endif
#endif
#if defined(_MSC_VER)
# define PYBIND11_NOINLINE __declspec(noinline)
#else
# define PYBIND11_NOINLINE __attribute__ ((noinline))
#endif
#if defined(PYBIND11_CPP14)
# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
#else
# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
#endif
#define PYBIND11_VERSION_MAJOR 2
#define PYBIND11_VERSION_MINOR 3
#define PYBIND11_VERSION_PATCH 0
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
# define HAVE_ROUND 1
# endif
# pragma warning(push)
# pragma warning(disable: 4510 4610 4512 4005)
# if defined(_DEBUG)
# define PYBIND11_DEBUG_MARKER
# undef _DEBUG
# endif
#endif
#include <Python.h>
#include <frameobject.h>
#include <pythread.h>
#if defined(_WIN32) && (defined(min) || defined(max))
# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
#if defined(isalnum)
# undef isalnum
# undef isalpha
# undef islower
# undef isspace
# undef isupper
# undef tolower
# undef toupper
#endif
#if defined(_MSC_VER)
# if defined(PYBIND11_DEBUG_MARKER)
# define _DEBUG
# undef PYBIND11_DEBUG_MARKER
# endif
# pragma warning(pop)
#endif
#include <cstddef>
#include <cstring>
#include <forward_list>
#include <vector>
#include <string>
#include <stdexcept>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <typeindex>
#include <type_traits>
#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyBytes_Check
#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
#define PYBIND11_BYTES_SIZE PyBytes_Size
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
#define PYBIND11_BYTES_NAME "bytes"
#define PYBIND11_STRING_NAME "str"
#define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
#else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyString_Check
#define PYBIND11_BYTES_FROM_STRING PyString_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyString_AsString
#define PYBIND11_BYTES_SIZE PyString_Size
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
#define PYBIND11_BYTES_NAME "str"
#define PYBIND11_STRING_NAME "unicode"
#define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \
(void)pybind11_init_wrapper(); \
} \
PyObject *pybind11_init_wrapper()
#endif
#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
extern "C" {
struct _Py_atomic_address { void *value; };
PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
}
#endif
#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
#define PYBIND11_STRINGIFY(x) #x
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
#define PYBIND11_CONCAT(first, second) first##second
#define PYBIND11_CHECK_PYTHON_VERSION \
{ \
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
const char *runtime_ver = Py_GetVersion(); \
size_t len = std::strlen(compiled_ver); \
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for Python %s, " \
"but the interpreter version is incompatible: %s.", \
compiled_ver, runtime_ver); \
return nullptr; \
} \
}
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
/** \rst
***Deprecated in favor of PYBIND11_MODULE***
This macro creates the entry point that will be invoked when the Python interpreter
imports a plugin library. Please create a `module` in the function body and return
the pointer to its underlying Python object at the end.
.. code-block:: cpp
PYBIND11_PLUGIN(example) {
pybind11::module m("example", "pybind11 example plugin");
/// Set up bindings here
return m.ptr();
}
\endrst */
#define PYBIND11_PLUGIN(name) \
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
static PyObject *pybind11_init(); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
try { \
return pybind11_init(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
PyObject *pybind11_init()
/** \rst
This macro creates the entry point that will be invoked when the Python interpreter
imports an extension module. The module name is given as the first argument and it
should not be in quotes. The second macro argument defines a variable of type
`py::module` which can be used to initialize the module.
.. code-block:: cpp
PYBIND11_MODULE(example, m) {
m.doc() = "pybind11 example module";
// Add bindings here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
using ssize_t = Py_ssize_t;
using size_t = std::size_t;
/// Approach used to cast a previously unknown C++ instance into a Python object
enum class return_value_policy : uint8_t {
/** This is the default return value policy, which falls back to the policy
return_value_policy::take_ownership when the return value is a pointer.
Otherwise, it uses return_value_policy::move or return_value_policy::copy
for rvalue and lvalue references, respectively. See below for a
description of what all of these different policies do. */
automatic = 0,
/** As above, but use policy return_value_policy::reference when the return
value is a pointer. This is the default conversion policy for function
arguments when calling Python functions manually from C++ code (i.e. via
handle::operator()). You probably won't need to use this. */
automatic_reference,
/** Reference an existing object (i.e. do not create a new copy) and take
ownership. Python will call the destructor and delete operator when the
object's reference count reaches zero. Undefined behavior ensues when
the C++ side does the same. */
take_ownership,
/** Create a new copy of the returned object, which will be owned by
Python. This policy is comparably safe because the lifetimes of the two
instances are decoupled. */
copy,
/** Use std::move to move the return value contents into a new instance
that will be owned by Python. This policy is comparably safe because the
lifetimes of the two instances (move source and destination) are
decoupled. */
move,
/** Reference an existing object, but do not take ownership. The C++ side
is responsible for managing the object's lifetime and deallocating it
when it is no longer used. Warning: undefined behavior will ensue when
the C++ side deletes an object that is still referenced and used by
Python. */
reference,
/** This policy only applies to methods and properties. It references the
object without taking ownership similar to the above
return_value_policy::reference policy. In contrast to that policy, the
function's or property's implicit this argument (called the parent) is
considered to be the owner of the return value (the child).
pybind11 then couples the lifetime of the parent to the child via a
reference relationship that ensures that the parent cannot be garbage
collected while Python is still using the child. More advanced
variations of this scheme are also possible using combinations of
return_value_policy::reference and the keep_alive call policy */
reference_internal
};
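/** \rst
    A sketch of selecting a policy at binding time (``Registry::instance``
    is an assumed singleton accessor):
    .. code-block:: cpp
        // The C++ side keeps ownership of the returned object; Python must
        // never delete it:
        m.def("get_registry", &Registry::instance,
              py::return_value_policy::reference);
    \endrst */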
NAMESPACE_BEGIN(detail)
inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
// Returns the size as a multiple of sizeof(void *), rounded up.
inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
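// Worked example on a typical 64-bit platform (sizeof(void *) == 8, log2 == 3):
// size_in_ptrs(8) == 1, size_in_ptrs(9) == 2, size_in_ptrs(24) == 3.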
/**
* The space to allocate for simple layout instance holders (see below) in multiple of the size of
* a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
 * to hold either a std::unique_ptr or std::shared_ptr (which is almost always
* sizeof(std::shared_ptr<T>)).
*/
constexpr size_t instance_simple_holder_in_ptrs() {
static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
"pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
return size_in_ptrs(sizeof(std::shared_ptr<int>));
}
// Forward declarations
struct type_info;
struct value_and_holder;
struct nonsimple_values_and_holders {
void **values_and_holders;
uint8_t *status;
};
/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
struct instance {
PyObject_HEAD
/// Storage for pointers and holder; see simple_layout, below, for a description
union {
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
nonsimple_values_and_holders nonsimple;
};
/// Weak references
PyObject *weakrefs;
/// If true, the pointer is owned which means we're free to manage it with a holder.
bool owned : 1;
/**
* An instance has two possible value/holder layouts.
*
* Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
* and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
* whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
* holder will fit in the default space (which is large enough to hold either a std::unique_ptr
* or std::shared_ptr).
*
* Non-simple layout applies when using custom holders that require more space than `shared_ptr`
* (which is typically the size of two pointers), or when multiple inheritance is used on the
* python side. Non-simple layout allocates the required amount of memory to have multiple
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
 * pointer to an allocated block large enough to hold a sequence of value pointers and
 * holders followed by `status`, a set of bit flags (1 byte each), i.e.
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
* beginning of the [bb...] block (but not independently allocated).
*
* Status bits indicate whether the associated holder is constructed (&
* status_holder_constructed) and whether the value pointer is registered (&
* status_instance_registered) in `registered_instances`.
*/
bool simple_layout : 1;
/// For simple layout, tracks whether the holder has been constructed
bool simple_holder_constructed : 1;
/// For simple layout, tracks whether the instance is registered in `registered_instances`
bool simple_instance_registered : 1;
/// If true, get_internals().patients has an entry for this object
bool has_patients : 1;
/// Initializes all of the above type/values/holders data (but not the instance values themselves)
void allocate_layout();
/// Destroys/deallocates all of the above
void deallocate_layout();
/// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
/// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
/// `throw_if_missing` is false.
value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
/// Bit values for the non-simple status flags
static constexpr uint8_t status_holder_constructed = 1;
static constexpr uint8_t status_instance_registered = 2;
};
static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
/// from __cpp_future__ import (convenient aliases from C++14/17)
#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
using std::enable_if_t;
using std::conditional_t;
using std::remove_cv_t;
using std::remove_reference_t;
#else
template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
#endif
/// Index sequences
#if defined(PYBIND11_CPP14)
using std::index_sequence;
using std::make_index_sequence;
#else
template<size_t ...> struct index_sequence { };
template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
#endif
/// Make an index sequence of the indices of true arguments
template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
: select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
/// Backports of std::bool_constant and std::negation to accommodate older compilers
template <bool B> using bool_constant = std::integral_constant<bool, B>;
template <typename T> struct negation : bool_constant<!T::value> { };
template <typename...> struct void_t_impl { using type = void; };
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
/// Compile-time all/any/none of that check the boolean value of all template types
#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
#elif !defined(_MSC_VER)
template <bool...> struct bools {};
template <class... Ts> using all_of = std::is_same<
bools<Ts::value..., true>,
bools<true, Ts::value...>>;
template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
#else
// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
// at a slight loss of compilation efficiency).
template <class... Ts> using all_of = std::conjunction<Ts...>;
template <class... Ts> using any_of = std::disjunction<Ts...>;
#endif
template <class... Ts> using none_of = negation<any_of<Ts...>>;
template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
/// Strip the class from a method type
template <typename T> struct remove_class { };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
/// Helper template to strip away type modifiers
template <typename T> struct intrinsic_type { typedef T type; };
template <typename T> struct intrinsic_type<const T> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T*> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&&> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
/// Helper type to replace 'void' in some expressions
struct void_type { };
/// Helper template which holds a list of types
template <typename...> struct type_list { };
/// Compile-time integer sum
#ifdef __cpp_fold_expressions
template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
#else
constexpr size_t constexpr_sum() { return 0; }
template <typename T, typename... Ts>
constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
#endif
NAMESPACE_BEGIN(constexpr_impl)
/// Implementation details for constexpr functions
constexpr int first(int i) { return i; }
template <typename T, typename... Ts>
constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
constexpr int last(int /*i*/, int result) { return result; }
template <typename T, typename... Ts>
constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
NAMESPACE_END(constexpr_impl)
/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
/// none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
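// Illustrative (assumed) examples:
//   constexpr_first<std::is_integral, float, int, double>() == 1
//   constexpr_last<std::is_integral, float, int, double>() == 1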
/// Return the Nth element from the parameter pack
template <size_t N, typename T, typename... Ts>
struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
template <typename T, typename... Ts>
struct pack_element<0, T, Ts...> { using type = T; };
/// Return the one and only type which matches the predicate, or Default if none match.
/// If more than one type matches the predicate, fail at compile-time.
template <template<typename> class Predicate, typename Default, typename... Ts>
struct exactly_one {
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
static_assert(found <= 1, "Found more than one type matching the predicate");
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
};
template <template<typename> class P, typename Default>
struct exactly_one<P, Default> { using type = Default; };
template <template<typename> class Predicate, typename Default, typename... Ts>
using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
/// Defer the evaluation of type T until types Us are instantiated
template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
/// unlike `std::is_base_of`)
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
/// can be converted to a Base pointer)
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
template <template<typename...> class Base>
struct is_template_base_of_impl {
template <typename... Us> static std::true_type check(Base<Us...> *);
static std::false_type check(...);
};
/// Check if a template is the base of a type. For example:
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
template <template<typename...> class Base, typename T>
#if !defined(_MSC_VER)
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
#else // MSVC2015 has trouble with decltype in template aliases
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
#endif
/// Check if T is an instantiation of the template `Class`. For example:
/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
template <template<typename...> class Class, typename T>
struct is_instantiation : std::false_type { };
template <template<typename...> class Class, typename... Us>
struct is_instantiation<Class, Class<Us...>> : std::true_type { };
/// Check if T is std::shared_ptr<U> where U can be anything
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
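// For example:
//     static_assert(is_instantiation<std::shared_ptr, std::shared_ptr<int>>::value, "");
//     static_assert(!is_shared_ptr<std::unique_ptr<int>>::value, "");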
/// Check if T looks like an input iterator
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
template <typename T>
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
: std::true_type {};
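// For example, raw pointers qualify (both operations are well-formed), while
// plain values do not:
//     static_assert(is_input_iterator<int *>::value, "");
//     static_assert(!is_input_iterator<int>::value, "");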
template <typename T> using is_function_pointer = bool_constant<
std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
template <typename F> struct strip_function_object {
using type = typename remove_class<decltype(&F::operator())>::type;
};
// Extracts the function signature from a function, function pointer or lambda.
template <typename Function, typename F = remove_reference_t<Function>>
using function_signature_t = conditional_t<
std::is_function<F>::value,
F,
typename conditional_t<
std::is_pointer<F>::value || std::is_member_pointer<F>::value,
std::remove_pointer<F>,
strip_function_object<F>
>::type
>;
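// For example, the signature of a free-function pointer is recovered as:
//     static_assert(std::is_same<function_signature_t<int (*)(double)>,
//                                int (double)>::value, "");
// while for a lambda it is taken from the lambda's operator().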
/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
/// in a place where passing a lambda makes sense.
template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
std::is_function, std::is_pointer, std::is_member_pointer>;
/// Ignore that a variable is unused in compiler warnings
inline void ignore_unused(const int *) { }
/// Apply a function over each element of a parameter pack
#ifdef __cpp_fold_expressions
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
#else
using expand_side_effects = bool[];
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false }
#endif
NAMESPACE_END(detail)
/// C++ bindings of builtin Python exceptions
class builtin_exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
/// Set the error using the Python C API
virtual void set_error() const = 0;
};
#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
class name : public builtin_exception { public: \
using builtin_exception::builtin_exception; \
name() : name("") { } \
void set_error() const override { PyErr_SetString(type, what()); } \
};
PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
template <typename T, typename SFINAE = void> struct format_descriptor { };
NAMESPACE_BEGIN(detail)
// Returns the index of the given type in the type char array below, and in the list in numpy.h
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
// complex float,double,long double. Note that the long double types only participate when long
// double is actually longer than double (it isn't under MSVC).
// NB: not only the string below but also complex.h and numpy.h rely on this order.
template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr bool value = true;
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
};
NAMESPACE_END(detail)
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
};
#if !defined(PYBIND11_CPP17)
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
#endif
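// For example, format_descriptor<double>::format() returns "d", and (on typical
// platforms where sizeof(unsigned short) == 2)
// format_descriptor<unsigned short>::format() returns "H".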
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;
error_scope() { PyErr_Fetch(&type, &value, &trace); }
~error_scope() { PyErr_Restore(type, value, trace); }
};
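// Typical usage is as a stack guard around code that must not clobber an
// in-flight Python error, e.g. (sketch):
//     { error_scope guard; /* call Python C API functions here */ }
// Any fetched error is restored when `guard` goes out of scope.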
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
struct nodelete { template <typename T> void operator()(T*) { } };
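// For example, binding a hypothetical `MyClass` whose destructor is private
// (so Python never deletes the instance):
//     py::class_<MyClass, std::unique_ptr<MyClass, py::nodelete>>(m, "MyClass");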
// overload_cast requires variable templates: C++14
#if defined(PYBIND11_CPP14)
#define PYBIND11_OVERLOAD_CAST 1
NAMESPACE_BEGIN(detail)
template <typename... Args>
struct overload_cast_impl {
constexpr overload_cast_impl() {} // MSVC 2015 needs this
template <typename Return>
constexpr auto operator()(Return (*pf)(Args...)) const noexcept
-> decltype(pf) { return pf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept
-> decltype(pmf) { return pmf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept
-> decltype(pmf) { return pmf; }
};
NAMESPACE_END(detail)
/// Syntax sugar for resolving overloaded function pointers:
/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
template <typename... Args>
static constexpr detail::overload_cast_impl<Args...> overload_cast = {};
// MSVC 2015 only accepts this particular initialization syntax for this variable template.
/// Const member function selector for overload_cast
/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
/// - sweet: overload_cast<Arg>(&Class::func, const_)
static constexpr auto const_ = std::true_type{};
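/// Example, assuming a hypothetical `Pet` class with two `set` overloads and a
/// const `name` accessor:
///   py::class_<Pet>(m, "Pet")
///       .def("set", overload_cast<int>(&Pet::set))
///       .def("set", overload_cast<const std::string &>(&Pet::set))
///       .def("name", overload_cast<>(&Pet::name, const_));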
#else // no overload_cast: providing something that static_assert-fails:
template <typename... Args> struct overload_cast {
static_assert(detail::deferred_t<std::false_type, Args...>::value,
"pybind11::overload_cast<...> requires compiling in C++14 mode");
};
#endif // overload_cast
NAMESPACE_BEGIN(detail)
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
// any standard container (or C-style array) supporting std::begin/std::end, any singleton
// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
template <typename T>
class any_container {
std::vector<T> v;
public:
any_container() = default;
// Can construct from a pair of iterators
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
any_container(It first, It last) : v(first, last) { }
// Implicit conversion constructor from any arbitrary container type with values convertible to T
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
// initializer_lists aren't deducible, so they don't get matched by the above template; we need this
// to explicitly allow implicit conversion from one:
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
// Avoid copying if given an rvalue vector of the correct type.
any_container(std::vector<T> &&v) : v(std::move(v)) { }
// Moves the vector out of an rvalue any_container
operator std::vector<T> &&() && { return std::move(v); }
// Dereferencing obtains a reference to the underlying vector
std::vector<T> &operator*() { return v; }
const std::vector<T> &operator*() const { return v; }
// -> lets you call methods on the underlying vector
std::vector<T> *operator->() { return &v; }
const std::vector<T> *operator->() const { return &v; }
};
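// For example, a parameter declared as `any_container<ssize_t>` accepts a
// std::vector<ssize_t> (moved in when given an rvalue), any begin()/end()
// container of convertible values, or a braced list (hypothetical signature):
//     void resize(any_container<ssize_t> shape);
//     resize({2, 3, 4});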
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,100 @@
/*
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if !defined(_MSC_VER)
# define PYBIND11_DESCR_CONSTEXPR static constexpr
#else
# define PYBIND11_DESCR_CONSTEXPR const
#endif
/* Concatenate type signatures at compile time */
template <size_t N, typename... Ts>
struct descr {
char text[N + 1];
constexpr descr() : text{'\0'} { }
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
template <size_t... Is>
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
template <typename... Chars>
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
return {{&typeid(Ts)..., nullptr}};
}
};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
}
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
}
template <size_t N>
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
constexpr descr<0> _(char const(&)[1]) { return {}; }
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
template <size_t...Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
};
// Ternary description (like std::conditional)
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
return _(text1);
}
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
return _(text2);
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
template <bool B, typename T1, typename T2>
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
}
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
constexpr descr<0> concat() { return {}; }
template <size_t N, typename... Ts>
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
template <size_t N, typename... Ts, typename... Args>
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
return d + _(", ") + concat(args...);
}
template <size_t N, typename... Ts>
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return _("{") + descr + _("}");
}
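// For example, these helpers compose signature text at compile time:
//     constexpr auto sig = type_descr(concat(_("int"), _("float")));
// yields a descr whose .text is "{int, float}", while _<3>() yields "3".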
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,335 @@
/*
pybind11/detail/init.h: init factory function implementation and support code.
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "class.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <>
class type_caster<value_and_holder> {
public:
bool load(handle h, bool) {
value = reinterpret_cast<value_and_holder *>(h.ptr());
return true;
}
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
private:
value_and_holder *value = nullptr;
};
NAMESPACE_BEGIN(initimpl)
inline void no_nullptr(void *ptr) {
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
}
// Implementing functions for all forms of py::init<...> and py::init(...)
template <typename Class> using Cpp = typename Class::type;
template <typename Class> using Alias = typename Class::type_alias;
template <typename Class> using Holder = typename Class::holder_type;
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
bool is_alias(Cpp<Class> *ptr) {
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
}
// Failing fallback version of the above for a no-alias class (always returns false)
template <typename /*Class*/>
constexpr bool is_alias(void *) { return false; }
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
// back to brace aggregate initialization so that it can be used with
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
// Attempts to construct an alias using an `Alias(Cpp &&)` constructor. This allows types with
// an alias to provide only a single Cpp factory function as long as the Alias can be
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
// can, when appropriate, simply define an `Alias(Cpp &&)` constructor rather than needing to
// inherit all the base class constructors.
template <typename Class>
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
value_and_holder &v_h, Cpp<Class> &&base) {
v_h.value_ptr() = new Alias<Class>(std::move(base));
}
template <typename Class>
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
value_and_holder &, Cpp<Class> &&) {
throw type_error("pybind11::init(): unable to convert returned instance to required "
"alias class: no `Alias<Class>(Class &&)` constructor available");
}
// Error-generating fallback for factories that don't match one of the below construction
// mechanisms.
template <typename Class>
void construct(...) {
static_assert(!std::is_same<Class, Class>::value /* always false */,
"pybind11::init(): init function must return a compatible pointer, "
"holder, or value");
}
// Pointer return v1: the factory function returns a class pointer for a registered class.
// If we don't need an alias (because this class doesn't have one, or because the final type is
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
// construct an Alias from the returned base instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
no_nullptr(ptr);
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
// We're going to try to construct an alias by moving the cpp type. Whether or not
// that succeeds, we still need to destroy the original cpp pointer (either the
// moved away leftover, if the alias construction works, or the value itself if we
// throw an error), but we can't just call `delete ptr`: it might have a special
// deleter, or might be shared_from_this. So we construct a holder around it as if
// it was a normal instance, then steal the holder away into a local variable; thus
// the holder and destruction happens when we leave the C++ scope, and the holder
// class gets to handle the destruction however it likes.
v_h.value_ptr() = ptr;
v_h.set_instance_registered(true); // To prevent init_instance from registering it
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
v_h.set_instance_registered(false);
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
} else {
// Otherwise the type isn't inherited, so we don't need an Alias
v_h.value_ptr() = ptr;
}
}
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
// ownership of the pointer.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
no_nullptr(alias_ptr);
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
}
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
// derived type (through those holder's implicit conversion from derived class holder constructors).
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
"is not an alias instance");
v_h.value_ptr() = ptr;
v_h.type->init_instance(v_h.inst, &holder);
}
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
// need it, we simply move-construct the cpp value into a new instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
static_assert(std::is_move_constructible<Cpp<Class>>::value,
"pybind11::init() return-by-value factory function requires a movable class");
if (Class::has_alias && need_alias)
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
else
v_h.value_ptr() = new Cpp<Class>(std::move(result));
}
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
// Alias instance (even if no Python-side inheritance is involved). This is intended for
// cases where Alias initialization is always desired.
template <typename Class>
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
static_assert(std::is_move_constructible<Alias<Class>>::value,
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
v_h.value_ptr() = new Alias<Class>(std::move(result));
}
// Implementing class for py::init<...>()
template <typename... Args>
struct constructor {
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
else
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
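// For example, `py::class_<Pet>(m, "Pet").def(py::init<const std::string &>());`
// (with a hypothetical `Pet`) routes through one of the `execute` overloads
// above, which allocate the C++ (or alias) instance and store it in the
// value_and_holder.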
// Implementing class for py::init_alias<...>()
template <typename... Args> struct alias_constructor {
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc = void_type (*)(),
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
struct factory;
// Specialization for py::init(Func)
template <typename Func, typename Return, typename... Args>
struct factory<Func, void_type (*)(), Return(Args...)> {
remove_reference_t<Func> class_factory;
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
// The given class either has no alias or has no separate alias factory;
// this always constructs the class itself. If the class is registered with an alias
// type and an alias instance is needed (i.e. because the final type is a Python class
// inheriting from the C++ type) the returned value needs to either already be an alias
// instance, or the alias needs to be constructible from a `Class &&` argument.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
#if defined(PYBIND11_CPP14)
cl.def("__init__", [func = std::move(class_factory)]
#else
auto &func = class_factory;
cl.def("__init__", [func]
#endif
(value_and_holder &v_h, Args... args) {
construct<Class>(v_h, func(std::forward<Args>(args)...),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
// Specialization for py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc,
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
remove_reference_t<CFunc> class_factory;
remove_reference_t<AFunc> alias_factory;
factory(CFunc &&c, AFunc &&a)
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
// The class factory is called when the `self` type passed to `__init__` is the direct
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra&... extra) && {
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
"only be used if the class has an alias");
#if defined(PYBIND11_CPP14)
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
#else
auto &class_func = class_factory;
auto &alias_func = alias_factory;
cl.def("__init__", [class_func, alias_func]
#endif
(value_and_holder &v_h, CArgs... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
// If the instance type equals the registered type we don't have inheritance, so
// don't need the alias and can construct using the class function:
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
else
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
}, is_new_style_constructor(), extra...);
}
};
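// For example (hypothetical `Example` class), a factory-based constructor:
//     py::class_<Example>(m, "Example")
//         .def(py::init([](const std::string &arg) { return new Example(arg); }));
// The factory's return value (pointer, holder, or value) is dispatched to the
// matching `construct` overload above.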
/// Set just the C++ state. Same as `__init__`.
template <typename Class, typename T>
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
construct<Class>(v_h, std::forward<T>(result), need_alias);
}
/// Set both the C++ and Python states
template <typename Class, typename T, typename O,
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *) v_h.inst, "__dict__", result.second);
}
/// Implementation for py::pickle(GetState, SetState)
template <typename Get, typename Set,
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
struct pickle_factory;
template <typename Get, typename Set,
typename RetState, typename Self, typename NewInstance, typename ArgState>
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
"The type returned by `__getstate__` must be the same "
"as the argument accepted by `__setstate__`");
remove_reference_t<Get> get;
remove_reference_t<Set> set;
pickle_factory(Get get, Set set)
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
cl.def("__getstate__", std::move(get));
#if defined(PYBIND11_CPP14)
cl.def("__setstate__", [func = std::move(set)]
#else
auto &func = set;
cl.def("__setstate__", [func]
#endif
(value_and_holder &v_h, ArgState state) {
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
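// For example (hypothetical `Pickleable` class with a `value()` accessor and a
// matching constructor):
//     py::class_<Pickleable>(m, "Pickleable")
//         .def(py::init<std::string>())
//         .def(py::pickle(
//             [](const Pickleable &p) { return py::make_tuple(p.value()); },
//             [](py::tuple t) { return Pickleable(t[0].cast<std::string>()); }));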
NAMESPACE_END(initimpl)
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,291 @@
/*
pybind11/detail/internals.h: Internal data structure and related functions
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../pytypes.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
// Forward declarations
inline PyTypeObject *make_static_property_type();
inline PyTypeObject *make_default_metaclass();
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
// Thread Specific Storage (TSS) API.
#if PY_VERSION_HEX >= 0x03070000
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
// Usually an int but a long on Cygwin64 with Python 3.x
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
# if PY_MAJOR_VERSION < 3
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_delete_key_value(key)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
do { \
PyThread_delete_key_value((key)); \
PyThread_set_key_value((key), (value)); \
} while (false)
# else
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_set_key_value((key), nullptr)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
PyThread_set_key_value((key), (value))
# endif
#endif
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
// which works. If not under a known-good stl, provide our own name-based hash and equality
// functions that use the type name.
#if defined(__GLIBCXX__)
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
using type_hash = std::hash<std::type_index>;
using type_equal_to = std::equal_to<std::type_index>;
#else
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
struct type_hash {
size_t operator()(const std::type_index &t) const {
size_t hash = 5381;
const char *ptr = t.name();
while (auto c = static_cast<unsigned char>(*ptr++))
hash = (hash * 33) ^ c;
return hash;
}
};
struct type_equal_to {
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
};
#endif
template <typename value_type>
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
struct overload_hash {
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
return value;
}
};
/// Internal data structure used to track registered instances and types.
/// Whenever binary incompatible changes are made to this structure,
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
struct internals {
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
PyObject *instance_base;
#if defined(WITH_THREAD)
PYBIND11_TLS_KEY_INIT(tstate);
PyInterpreterState *istate = nullptr;
#endif
};
/// Additional type information which does not fit into the PyTypeObject.
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
struct type_info {
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
};
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 3
#if defined(_DEBUG)
# define PYBIND11_BUILD_TYPE "_debug"
#else
# define PYBIND11_BUILD_TYPE ""
#endif
#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
#else
# define PYBIND11_INTERNALS_KIND "_without_thread"
#endif
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
/// Each module locally stores a pointer to the `internals` data. The data
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
inline internals **&get_internals_pp() {
static internals **internals_pp = nullptr;
return internals_pp;
}
/// Return a reference to the current `internals` data
PYBIND11_NOINLINE inline internals &get_internals() {
auto **&internals_pp = get_internals_pp();
if (internals_pp && *internals_pp)
return **internals_pp;
constexpr auto *id = PYBIND11_INTERNALS_ID;
auto builtins = handle(PyEval_GetBuiltins());
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
internals_pp = static_cast<internals **>(capsule(builtins[id]));
// We loaded the `internals` pointer through Python's builtins, which means that our `error_already_set`
// and `builtin_exception` may be different local classes than the ones set up in the
// initial exception translator, below, so add another for our local exception classes.
//
// libstdc++ doesn't require this (types there are identified only by name)
#if !defined(__GLIBCXX__)
(*internals_pp)->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
}
}
);
#endif
} else {
if (!internals_pp) internals_pp = new internals*();
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
PyEval_InitThreads();
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
#endif
builtins[id] = capsule(internals_pp);
internals_ptr->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
return;
}
}
);
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
}
return **internals_pp;
}
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
inline type_map<type_info *> &registered_local_types_cpp() {
static type_map<type_info *> locals{};
return locals;
}
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
// `c_str()`. Such string objects have a long storage duration -- the internal strings are only
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args>
const char *c_str(Args &&...args) {
auto &strings = get_internals().static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
}
NAMESPACE_END(detail)
/// Returns a named pointer that is shared among all extension modules (using the same
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
}
/// Set the shared data that can be later recovered by `get_shared_data()`.
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
}
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template<typename T>
T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
}
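// For example (name and payload type are arbitrary), a counter shared across
// all pybind11 extension modules in the current interpreter:
//     auto &counter = py::get_or_create_shared_data<size_t>("my_ext_counter");
//     ++counter;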
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,55 @@
/*
pybind11/detail/typeid.h: Compiler-independent access to type identifiers
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include <cstdio>
#include <cstdlib>
#if defined(__GNUG__)
#include <cxxabi.h>
#endif
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring
inline void erase_all(std::string &string, const std::string &search) {
for (size_t pos = 0;;) {
pos = string.find(search, pos);
if (pos == std::string::npos) break;
string.erase(pos, search.length());
}
}
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
#if defined(__GNUG__)
int status = 0;
std::unique_ptr<char, void (*)(void *)> res {
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
if (status == 0)
name = res.get();
#else
detail::erase_all(name, "class ");
detail::erase_all(name, "struct ");
detail::erase_all(name, "enum ");
#endif
detail::erase_all(name, "pybind11::");
}
NAMESPACE_END(detail)
/// Return a string representation of a C++ type
template <typename T> static std::string type_id() {
std::string name(typeid(T).name());
detail::clean_type_id(name);
return name;
}
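// For example, type_id<int>() returns "int" rather than a mangled name; the
// exact text for more complex types is platform-dependent.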
NAMESPACE_END(PYBIND11_NAMESPACE)

python/src/pybind11/eigen.h

@@ -0,0 +1,607 @@
/*
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "numpy.h"
#if defined(__INTEL_COMPILER)
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
# ifdef __clang__
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
// under Clang, so disable that warning here:
# pragma GCC diagnostic ignored "-Wdeprecated"
# endif
# if __GNUC__ >= 7
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
# endif
#endif
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
#endif
#include <Eigen/Core>
#include <Eigen/SparseCore>
// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
// move constructors that break things. We could detect this and explicitly copy, but an extra copy
// of matrices seems highly undesirable.
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
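// For example, a binding declared as
//     m.def("scale", [](py::EigenDRef<Eigen::MatrixXd> mat, double c) { mat *= c; });
// accepts (and can modify in place) a 2D float64 numpy array with arbitrary
// strides, since the fully dynamic stride type avoids forcing a copy.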
NAMESPACE_BEGIN(detail)
#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif
// Matches Eigen::Map, Eigen::Ref, blocks, etc:
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
// basically covers anything that can be assigned to a dense matrix but that doesn't have a typical
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
is_template_base_of<Eigen::EigenBase, T>,
negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;
// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
template <bool EigenRowMajor> struct EigenConformable {
bool conformable = false;
EigenIndex rows = 0, cols = 0;
EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
bool negativestrides = false; // If true, do not use stride!
EigenConformable(bool fits = false) : conformable{fits} {}
// Matrix type:
EigenConformable(EigenIndex r, EigenIndex c,
EigenIndex rstride, EigenIndex cstride) :
conformable{true}, rows{r}, cols{c} {
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
if (rstride < 0 || cstride < 0) {
negativestrides = true;
} else {
stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
EigenRowMajor ? cstride : rstride /* inner stride */ };
}
}
// Vector type:
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
: EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
template <typename props> bool stride_compatible() const {
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
// matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
return
!negativestrides &&
(props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
(EigenRowMajor ? cols : rows) == 1) &&
(props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
(EigenRowMajor ? rows : cols) == 1);
}
operator bool() const { return conformable; }
};
template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
// Helper struct for extracting information from an Eigen type
template <typename Type_> struct EigenProps {
using Type = Type_;
using Scalar = typename Type::Scalar;
using StrideType = typename eigen_extract_stride<Type>::type;
static constexpr EigenIndex
rows = Type::RowsAtCompileTime,
cols = Type::ColsAtCompileTime,
size = Type::SizeAtCompileTime;
static constexpr bool
row_major = Type::IsRowMajor,
vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
fixed_rows = rows != Eigen::Dynamic,
fixed_cols = cols != Eigen::Dynamic,
fixed = size != Eigen::Dynamic, // Fully-fixed size
dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
vector ? size : row_major ? cols : rows>::value;
static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
// Takes an input array and determines whether we can make it fit into the Eigen type. If
// the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
// (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
static EigenConformable<row_major> conformable(const array &a) {
const auto dims = a.ndim();
if (dims < 1 || dims > 2)
return false;
if (dims == 2) { // Matrix type: require exact match (or dynamic)
EigenIndex
np_rows = a.shape(0),
np_cols = a.shape(1),
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
return false;
return {np_rows, np_cols, np_rstride, np_cstride};
}
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
// is used, we want the (single) numpy stride value.
const EigenIndex n = a.shape(0),
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
if (vector) { // Eigen type is a compile-time vector
if (fixed && size != n)
return false; // Vector size mismatch
return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
}
else if (fixed) {
// The type has a fixed size, but is not a vector: abort
return false;
}
else if (fixed_cols) {
// Since this isn't a vector, cols must be != 1. We allow this only if it exactly
// equals the number of elements (rows is Dynamic, and so 1 row is allowed).
if (cols != n) return false;
return {1, n, stride};
}
else {
// Otherwise it's either fully dynamic, or column dynamic; both become a column vector
if (fixed_rows && rows != n) return false;
return {n, 1, stride};
}
}
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
static constexpr bool show_c_contiguous = show_order && requires_row_major;
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
static constexpr auto descriptor =
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
_("]") +
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
// *gave* a numpy.ndarray of the right type and dimensions).
_<show_writeable>(", flags.writeable", "") +
_<show_c_contiguous>(", flags.c_contiguous", "") +
_<show_f_contiguous>(", flags.f_contiguous", "") +
_("]");
};
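// For example, for Eigen::Matrix3f the descriptor above reads
// "numpy.ndarray[float32[3, 3]]"; dynamic dimensions appear as "m"/"n", and the
// optional flags are appended only for map/ref types.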
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
array a;
if (props::vector)
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
else
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
src.data(), base);
if (!writeable)
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
return a.release();
}
// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
// references the Eigen object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename props, typename Type>
handle eigen_ref_array(Type &src, handle parent = none()) {
// none here is to get past array's should-we-copy detection, which currently always
// copies when there is no base. Setting the base to None should be harmless.
return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
}
// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the Type of the pointer given is const.
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
handle eigen_encapsulate(Type *src) {
capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
return eigen_ref_array<props>(*src, base);
}
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
using Scalar = typename Type::Scalar;
using props = EigenProps<Type>;
bool load(handle src, bool convert) {
// If we're in no-convert mode, only load if given an array of the correct type
if (!convert && !isinstance<array_t<Scalar>>(src))
return false;
// Coerce into an array, but don't do type conversion yet; the copy below handles it.
auto buf = array::ensure(src);
if (!buf)
return false;
auto dims = buf.ndim();
if (dims < 1 || dims > 2)
return false;
auto fits = props::conformable(buf);
if (!fits)
return false;
// Allocate the new type, then build a numpy reference into it
value = Type(fits.rows, fits.cols);
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
if (dims == 1) ref = ref.squeeze();
else if (ref.ndim() == 1) buf = buf.squeeze();
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
if (result < 0) { // Copy failed!
PyErr_Clear();
return false;
}
return true;
}
private:
// Cast implementation
template <typename CType>
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::take_ownership:
case return_value_policy::automatic:
return eigen_encapsulate<props>(src);
case return_value_policy::move:
return eigen_encapsulate<props>(new CType(std::move(*src)));
case return_value_policy::copy:
return eigen_array_cast<props>(*src);
case return_value_policy::reference:
case return_value_policy::automatic_reference:
return eigen_ref_array<props>(*src);
case return_value_policy::reference_internal:
return eigen_ref_array<props>(*src, parent);
default:
throw cast_error("unhandled return_value_policy: should not happen!");
};
}
public:
// Normal returned non-reference, non-const value:
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// If you return a non-reference const, we mark the numpy array readonly:
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// lvalue reference return; default (automatic) becomes copy
static handle cast(Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast_impl(&src, policy, parent);
}
// const lvalue reference return; default (automatic) becomes copy
static handle cast(const Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast(&src, policy, parent);
}
// non-const pointer return
static handle cast(Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
// const pointer return
static handle cast(const Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
static constexpr auto name = props::descriptor;
operator Type*() { return &value; }
operator Type&() { return value; }
operator Type&&() && { return std::move(value); }
template <typename T> using cast_op_type = movable_cast_op_type<T>;
private:
Type value;
};
// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
using props = EigenProps<MapType>;
public:
// Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
// to stay around), but we'll allow it under the assumption that you know what you're doing (and
// have an appropriate keep_alive in place). We return a numpy array pointing directly at the
// ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note
// that this means you need to ensure you don't destroy the object in some other way (e.g. with
// an appropriate keep_alive, or with a reference to a statically allocated matrix).
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::copy:
return eigen_array_cast<props>(src);
case return_value_policy::reference_internal:
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
case return_value_policy::reference:
case return_value_policy::automatic:
case return_value_policy::automatic_reference:
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
default:
// move, take_ownership don't make any sense for a ref/map:
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
}
}
static constexpr auto name = props::descriptor;
// Explicitly deleted: we don't support Python -> C++ conversion for these (i.e. they can be
// return types but not bound arguments). We still declare them (explicitly deleted) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator MapType() = delete;
template <typename> using cast_op_type = MapType;
};
// We can return any map-like object (but can only load Refs, specialized next):
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
: eigen_map_caster<Type> {};
// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
Eigen::Ref<PlainObjectType, 0, StrideType>,
enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
using props = EigenProps<Type>;
using Scalar = typename props::Scalar;
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
using Array = array_t<Scalar, array::forcecast |
((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
(props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
// Delay construction (these have no default constructor)
std::unique_ptr<MapType> map;
std::unique_ptr<Type> ref;
// Our array. When possible, this is just a numpy array pointing to the source data, but
// sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
// layout, or is an array of a type that needs to be converted). Using a numpy temporary
// (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
Array copy_or_ref;
public:
bool load(handle src, bool convert) {
// First check whether what we have is already an array of the right type. If not, we can't
// avoid a copy (because the copy is also going to do type conversion).
bool need_copy = !isinstance<Array>(src);
EigenConformable<props::row_major> fits;
if (!need_copy) {
// We don't need a converting copy, but we also need to check whether the strides are
// compatible with the Ref's stride requirements
Array aref = reinterpret_borrow<Array>(src);
if (aref && (!need_writeable || aref.writeable())) {
fits = props::conformable(aref);
if (!fits) return false; // Incompatible dimensions
if (!fits.template stride_compatible<props>())
need_copy = true;
else
copy_or_ref = std::move(aref);
}
else {
need_copy = true;
}
}
if (need_copy) {
// We need to copy: if we need a mutable reference, or we're not supposed to convert
// (either because we're in the no-convert overload pass, or because we're explicitly
// instructed not to copy via `py::arg().noconvert()`), we have to fail loading.
if (!convert || need_writeable) return false;
Array copy = Array::ensure(src);
if (!copy) return false;
fits = props::conformable(copy);
if (!fits || !fits.template stride_compatible<props>())
return false;
copy_or_ref = std::move(copy);
loader_life_support::add_patient(copy_or_ref);
}
ref.reset();
map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
ref.reset(new Type(*map));
return true;
}
operator Type*() { return ref.get(); }
operator Type&() { return *ref; }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
private:
template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
Scalar *data(Array &a) { return a.mutable_data(); }
template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
const Scalar *data(Array &a) { return a.data(); }
// Attempt to figure out a constructor of `Stride` that will work.
// If both strides are fixed, use a default constructor:
template <typename S> using stride_ctor_default = bool_constant<
S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_default_constructible<S>::value>;
// Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
// Eigen::Stride, and use it:
template <typename S> using stride_ctor_dual = bool_constant<
!stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
// Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
// it (passing whichever stride is dynamic).
template <typename S> using stride_ctor_outer = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S> using stride_ctor_inner = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex) { return S(); }
template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
};
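// Example (sketch, illustrative names): what the Ref<...> loader accepts. The
// read-only overload may fall back to the converting copy above; the mutable
// one requires a writeable, layout-compatible float64 array and never copies:
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/eigen.h>
//     namespace py = pybind11;
//
//     double first(Eigen::Ref<const Eigen::MatrixXd> m) { return m(0, 0); }
//     void zero_first(Eigen::Ref<Eigen::MatrixXd> m) { m(0, 0) = 0; }
//
//     PYBIND11_MODULE(example, m) {
//         m.def("first", &first);           // lists, int arrays, ... all fine (may copy)
//         m.def("zero_first", &zero_first); // rejects anything that would need a copy
//     }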
// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
// load() is not supported, but we can cast them into the python domain by first copying to a
// regular Eigen::Matrix, then casting that.
template <typename Type>
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
protected:
using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
using props = EigenProps<Matrix>;
public:
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
handle h = eigen_encapsulate<props>(new Matrix(src));
return h;
}
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
static constexpr auto name = props::descriptor;
// Explicitly deleted: we don't support Python -> C++ conversion for these (i.e. they can be
// return types but not bound arguments). We still declare them (explicitly deleted) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator Type() = delete;
template <typename> using cast_op_type = Type;
};
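// Example (sketch, illustrative name): a returned special matrix is densified
// into a regular 2-D numpy array via the Matrix copy above:
//
//     m.def("diag3", []() {
//         return Eigen::DiagonalMatrix<double, 3>(1.0, 2.0, 3.0);
//     });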
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
typedef typename Type::Scalar Scalar;
typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
typedef typename Type::Index Index;
static constexpr bool rowMajor = Type::IsRowMajor;
bool load(handle src, bool) {
if (!src)
return false;
auto obj = reinterpret_borrow<object>(src);
object sparse_module = module::import("scipy.sparse");
object matrix_type = sparse_module.attr(
rowMajor ? "csr_matrix" : "csc_matrix");
if (!obj.get_type().is(matrix_type)) {
try {
obj = matrix_type(obj);
} catch (const error_already_set &) {
return false;
}
}
auto values = array_t<Scalar>((object) obj.attr("data"));
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
auto nnz = obj.attr("nnz").cast<Index>();
if (!values || !innerIndices || !outerIndices)
return false;
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
return true;
}
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
const_cast<Type&>(src).makeCompressed();
object matrix_type = module::import("scipy.sparse").attr(
rowMajor ? "csr_matrix" : "csc_matrix");
array data(src.nonZeros(), src.valuePtr());
array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
array innerIndices(src.nonZeros(), src.innerIndexPtr());
return matrix_type(
std::make_tuple(data, innerIndices, outerIndices),
std::make_pair(src.rows(), src.cols())
).release();
}
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
+ npy_format_descriptor<Scalar>::name + _("]"));
};
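// Example (sketch, illustrative names): round-tripping with scipy.sparse. A
// column-major Eigen::SparseMatrix maps to scipy.sparse.csc_matrix (csr_matrix
// when row-major):
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/eigen.h>
//     namespace py = pybind11;
//
//     double total(const Eigen::SparseMatrix<double> &m) { return m.sum(); }
//
//     PYBIND11_MODULE(example, m) {
//         m.def("total", &total); // accepts (or converts its argument to) a csc_matrix
//     }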
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic pop
#elif defined(_MSC_VER)
# pragma warning(pop)
#endif

python/src/pybind11/embed.h Normal file

@@ -0,0 +1,200 @@
/*
pybind11/embed.h: Support for embedding the interpreter
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include "eval.h"
#if defined(PYPY_VERSION)
# error Embedding the interpreter is not supported with PyPy
#endif
#if PY_MAJOR_VERSION >= 3
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" PyObject *pybind11_init_impl_##name() { \
return pybind11_init_wrapper_##name(); \
}
#else
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" void pybind11_init_impl_##name() { \
pybind11_init_wrapper_##name(); \
}
#endif
/** \rst
Add a new module to the table of builtins for the interpreter. Must be
defined in global scope. The first macro parameter is the name of the
module (without quotes). The second parameter is the variable which will
be used as the interface to add functions and classes to the module.
.. code-block:: cpp
PYBIND11_EMBEDDED_MODULE(example, m) {
// ... initialize functions and classes here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
struct embedded_module {
#if PY_MAJOR_VERSION >= 3
using init_t = PyObject *(*)();
#else
using init_t = void (*)();
#endif
embedded_module(const char *name, init_t init) {
if (Py_IsInitialized())
pybind11_fail("Can't add new modules after the interpreter has been initialized");
auto result = PyImport_AppendInittab(name, init);
if (result == -1)
pybind11_fail("Insufficient memory to add a new module");
}
};
NAMESPACE_END(detail)
/** \rst
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
called before this is done, with the exception of `PYBIND11_EMBEDDED_MODULE`. The
optional parameter can be used to skip the registration of signal handlers (see the
`Python documentation`_ for details). Calling this function again after the interpreter
has already been initialized is a fatal error.
If initializing the Python interpreter fails, then the program is terminated. (This
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
of throwing exceptions on errors.)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");
Py_InitializeEx(init_signal_handlers ? 1 : 0);
// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
}
/** \rst
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
after this. In addition, pybind11 objects must not outlive the interpreter:
.. code-block:: cpp
{ // BAD
py::initialize_interpreter();
auto hello = py::str("Hello, World!");
py::finalize_interpreter();
} // <-- BOOM, hello's destructor is called after interpreter shutdown
{ // GOOD
py::initialize_interpreter();
{ // scoped
auto hello = py::str("Hello, World!");
} // <-- OK, hello is cleaned up properly
py::finalize_interpreter();
}
{ // BETTER
py::scoped_interpreter guard{};
auto hello = py::str("Hello, World!");
}
.. warning::
The interpreter can be restarted by calling `initialize_interpreter` again.
Modules created using pybind11 can be safely re-initialized. However, Python
itself cannot completely unload binary extension modules and there are several
caveats with regard to interpreter restarting. All the details can be found
in the CPython documentation. In short, not all interpreter memory may be
freed, either due to reference cycles or user-created global data.
\endrst */
inline void finalize_interpreter() {
handle builtins(PyEval_GetBuiltins());
const char *id = PYBIND11_INTERNALS_ID;
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
// It could also be stashed in builtins, so look there too:
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
internals_ptr_ptr = capsule(builtins[id]);
Py_Finalize();
if (internals_ptr_ptr) {
delete *internals_ptr_ptr;
*internals_ptr_ptr = nullptr;
}
}
/** \rst
Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
This is a move-only guard; only a single instance can exist.
.. code-block:: cpp
#include <pybind11/embed.h>
int main() {
py::scoped_interpreter guard{};
py::print("Hello, World!");
} // <-- interpreter shutdown
\endrst */
class scoped_interpreter {
public:
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
}
scoped_interpreter(const scoped_interpreter &) = delete;
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
~scoped_interpreter() {
if (is_valid)
finalize_interpreter();
}
private:
bool is_valid = true;
};
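/** \rst
    A minimal end-to-end sketch combining `PYBIND11_EMBEDDED_MODULE` with
    `scoped_interpreter` (module and function names are illustrative):
    .. code-block:: cpp
        #include <pybind11/embed.h>
        namespace py = pybind11;
        PYBIND11_EMBEDDED_MODULE(fast_calc, m) {
            m.def("add", [](int i, int j) { return i + j; });
        }
        int main() {
            py::scoped_interpreter guard{};
            auto calc = py::module::import("fast_calc"); // found via the inittab entry
            return calc.attr("add")(1, 2).cast<int>() == 3 ? 0 : 1;
        }
\endrst */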
NAMESPACE_END(PYBIND11_NAMESPACE)

python/src/pybind11/eval.h Normal file

@@ -0,0 +1,117 @@
/*
pybind11/eval.h: Support for evaluating Python expressions and statements
from strings and files
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
enum eval_mode {
/// Evaluate a string containing an isolated expression
eval_expr,
/// Evaluate a string containing a single statement. Returns \c none
eval_single_statement,
/// Evaluate a string containing a sequence of statements. Returns \c none
eval_statements
};
template <eval_mode mode = eval_expr>
object eval(str expr, object global = globals(), object local = object()) {
if (!local)
local = global;
/* PyRun_String does not accept a PyObject / encoding specifier,
this seems to be the only alternative */
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
template <eval_mode mode = eval_expr, size_t N>
object eval(const char (&s)[N], object global = globals(), object local = object()) {
/* Support raw string literals by removing common leading whitespace */
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
: str(s);
return eval<mode>(expr, global, local);
}
inline void exec(str expr, object global = globals(), object local = object()) {
eval<eval_statements>(expr, global, local);
}
template <size_t N>
void exec(const char (&s)[N], object global = globals(), object local = object()) {
eval<eval_statements>(s, global, local);
}
template <eval_mode mode = eval_statements>
object eval_file(str fname, object global = globals(), object local = object()) {
if (!local)
local = global;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
int closeFile = 1;
std::string fname_str = (std::string) fname;
#if PY_VERSION_HEX >= 0x03040000
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
#elif PY_VERSION_HEX >= 0x03000000
FILE *f = _Py_fopen(fname.ptr(), "r");
#else
/* No unicode support in open() :( */
auto fobj = reinterpret_steal<object>(PyFile_FromString(
const_cast<char *>(fname_str.c_str()),
const_cast<char*>("r")));
FILE *f = nullptr;
if (fobj)
f = PyFile_AsFile(fobj.ptr());
closeFile = 0;
#endif
if (!f) {
PyErr_Clear();
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
}
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
local.ptr());
(void) closeFile;
#else
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
local.ptr(), closeFile);
#endif
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
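// Example (sketch): typical use of eval/exec with an explicit local scope.
// Assumes an initialized interpreter; `locals`/`x` are illustrative names:
//
//     auto locals = py::dict();
//     py::exec("x = 6 * 7", py::globals(), locals);      // eval_statements
//     auto x = py::eval("x + 0", py::globals(), locals); // eval_expr (the default)
//     assert(x.cast<int>() == 42);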
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,108 @@
/*
pybind11/functional.h: std::function<> support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <functional>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> {
using type = std::function<Return(Args...)>;
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
using function_type = Return (*) (Args...);
public:
bool load(handle src, bool convert) {
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert) return false;
return true;
}
if (!isinstance<function>(src))
return false;
auto func = reinterpret_borrow<function>(src);
/*
When passing a C++ function as an argument to another C++
function via Python, every function call would normally involve
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
Here, we try to at least detect the case where the function is
stateless (i.e. function pointer or lambda function without
captured variables), in which case the roundtrip can be avoided.
*/
if (auto cfunc = func.cpp_function()) {
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *) c;
if (rec && rec->is_stateless &&
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
struct capture { function_type f; };
value = ((capture *) &rec->data)->f;
return true;
}
}
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function&& f_) : f(std::move(f_)) {}
func_handle(const func_handle&) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
// gil_scoped_acquire acq;
// object retval(hfunc.f(std::forward<Args>(args)...));
// /* Visual studio 2015 parser issue: need parentheses around this expression */
// return (retval.template cast<Return>());
// };
struct func_wrapper {
func_handle hfunc;
func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
}
};
value = func_wrapper(func_handle(std::move(func)));
return true;
}
template <typename Func>
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
if (!f_)
return none().inc_ref();
auto result = f_.template target<function_type>();
if (result)
return cpp_function(*result, policy).release();
else
return cpp_function(std::forward<Func>(f_), policy).release();
}
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
+ make_caster<retval_type>::name + _("]"));
};
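// Example (sketch, illustrative names): with this caster, bound functions can
// accept Python callables (or previously bound stateless C++ functions, which
// skip the Python roundtrip) as std::function arguments:
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/functional.h>
//     namespace py = pybind11;
//
//     int apply_twice(const std::function<int(int)> &f, int x) { return f(f(x)); }
//
//     PYBIND11_MODULE(example, m) {
//         m.def("apply_twice", &apply_twice);
//         // Python: example.apply_twice(lambda v: v + 1, 3) -> 5
//     }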
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)


@@ -0,0 +1,207 @@
/*
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
Copyright (c) 2017 Henry F. Schreiner
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <streambuf>
#include <ostream>
#include <string>
#include <memory>
#include <iostream>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
// Buffer that writes to Python instead of C++
class pythonbuf : public std::streambuf {
private:
using traits_type = std::streambuf::traits_type;
const size_t buf_size;
std::unique_ptr<char[]> d_buffer;
object pywrite;
object pyflush;
int overflow(int c) {
if (!traits_type::eq_int_type(c, traits_type::eof())) {
*pptr() = traits_type::to_char_type(c);
pbump(1);
}
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
}
int sync() {
if (pbase() != pptr()) {
// This subtraction cannot be negative, so dropping the sign
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
{
gil_scoped_acquire tmp;
pywrite(line);
pyflush();
}
setp(pbase(), epptr());
}
return 0;
}
public:
pythonbuf(object pyostream, size_t buffer_size = 1024)
: buf_size(buffer_size),
d_buffer(new char[buf_size]),
pywrite(pyostream.attr("write")),
pyflush(pyostream.attr("flush")) {
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
}
/// Sync before destroy
~pythonbuf() {
sync();
}
};
NAMESPACE_END(detail)
/** \rst
This is a move-only guard that redirects output.
.. code-block:: cpp
#include <pybind11/iostream.h>
...
{
py::scoped_ostream_redirect output;
std::cout << "Hello, World!"; // Python stdout
} // <-- return std::cout to normal
You can explicitly pass the C++ stream and the Python object,
for example to guard stderr instead.
.. code-block:: cpp
{
py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
std::cerr << "Hello, World!";
}
\endrst */
class scoped_ostream_redirect {
protected:
std::streambuf *old;
std::ostream &costream;
detail::pythonbuf buffer;
public:
scoped_ostream_redirect(
std::ostream &costream = std::cout,
object pyostream = module::import("sys").attr("stdout"))
: costream(costream), buffer(pyostream) {
old = costream.rdbuf(&buffer);
}
~scoped_ostream_redirect() {
costream.rdbuf(old);
}
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
};
/** \rst
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
is provided primarily to make ``py::call_guard`` easier to use.
.. code-block:: cpp
m.def("noisy_func", &noisy_func,
py::call_guard<scoped_ostream_redirect,
scoped_estream_redirect>());
\endrst */
class scoped_estream_redirect : public scoped_ostream_redirect {
public:
scoped_estream_redirect(
std::ostream &costream = std::cerr,
object pyostream = module::import("sys").attr("stderr"))
: scoped_ostream_redirect(costream,pyostream) {}
};
NAMESPACE_BEGIN(detail)
// Class to redirect output as a context manager. C++ backend.
class OstreamRedirect {
bool do_stdout_;
bool do_stderr_;
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
public:
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
void enter() {
if (do_stdout_)
redirect_stdout.reset(new scoped_ostream_redirect());
if (do_stderr_)
redirect_stderr.reset(new scoped_estream_redirect());
}
void exit() {
redirect_stdout.reset();
redirect_stderr.reset();
}
};
NAMESPACE_END(detail)
/** \rst
This is a helper function to add a C++ redirect context manager to Python
instead of using a C++ guard. To use it, add the following to your binding code:
.. code-block:: cpp
#include <pybind11/iostream.h>
...
py::add_ostream_redirect(m, "ostream_redirect");
You now have a Python context manager that redirects your output:
.. code-block:: python
with m.ostream_redirect():
m.print_to_cout_function()
This manager can optionally be told which streams to operate on:
.. code-block:: python
with m.ostream_redirect(stdout=True, stderr=True):
m.noisy_function_with_error_printing()
\endrst */
inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
.def("__enter__", &detail::OstreamRedirect::enter)
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
}
NAMESPACE_END(PYBIND11_NAMESPACE)

python/src/pybind11/numpy.h Normal file
File diff suppressed because it is too large

View File

@@ -0,0 +1,168 @@
/*
pybind11/operator.h: Metatemplates for operator overloading
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#if defined(__clang__) && !defined(__INTEL_COMPILER)
# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
#elif defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Enumeration with all supported operator types
enum op_id : int {
op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
op_repr, op_truediv, op_itruediv, op_hash
};
enum op_type : int {
op_l, /* base type on left */
op_r, /* base type on right */
op_u /* unary operator */
};
struct self_t { };
static const self_t self = self_t();
/// Type for an unused type slot
struct undefined_t { };
/// Don't warn about an unused variable
inline self_t __self() { return self; }
/// base template of operator implementations
template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
/// Operator implementation generator
template <op_id id, op_type ot, typename L, typename R> struct op_ {
template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
};
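// Example (sketch, illustrative names): how these generators surface in
// binding code through the `py::self` sugar:
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/operators.h>
//     namespace py = pybind11;
//
//     struct Vec2 {
//         float x, y;
//         Vec2(float x, float y) : x(x), y(y) {}
//         Vec2 operator+(const Vec2 &o) const { return Vec2(x + o.x, y + o.y); }
//         Vec2 &operator+=(const Vec2 &o) { x += o.x; y += o.y; return *this; }
//         bool operator==(const Vec2 &o) const { return x == o.x && y == o.y; }
//     };
//
//     PYBIND11_MODULE(example, m) {
//         py::class_<Vec2>(m, "Vec2")
//             .def(py::init<float, float>())
//             .def(py::self + py::self)   // op_<op_add, op_l, self_t, self_t>
//             .def(py::self += py::self)
//             .def(py::self == py::self);
//     }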
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
static B execute_cast(const L &l, const R &r) { return B(expr); } \
}; \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
static char const* name() { return "__" #rid "__"; } \
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
static B execute_cast(const R &r, const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
return op_<op_##id, op_l, self_t, self_t>(); \
} \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
} \
template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
return op_<op_##id, op_r, T, self_t>(); \
}
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
static B execute_cast(L &l, const R &r) { return B(expr); } \
}; \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
}
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l) -> decltype(expr) { return expr; } \
static B execute_cast(const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
return op_<op_##id, op_u, self_t, undefined_t>(); \
}
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
#undef PYBIND11_BINARY_OPERATOR
#undef PYBIND11_INPLACE_OPERATOR
#undef PYBIND11_UNARY_OPERATOR
NAMESPACE_END(detail)
using detail::self;
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER)
# pragma warning(pop)
#endif


@@ -0,0 +1,65 @@
/*
pybind11/options.h: global settings that are configurable at runtime.
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
class options {
public:
// Default RAII constructor, which leaves settings as they currently are.
options() : previous_state(global_state()) {}
// Class is non-copyable.
options(const options&) = delete;
options& operator=(const options&) = delete;
// Destructor, which restores settings that were in effect before.
~options() {
global_state() = previous_state;
}
// Setter methods (affect the global state):
options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
// Getter methods (return the global state):
static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
static bool show_function_signatures() { return global_state().show_function_signatures; }
// This type is not meant to be allocated on the heap.
void* operator new(size_t) = delete;
private:
struct state {
bool show_user_defined_docstrings = true; ///< Include user-supplied texts in docstrings.
bool show_function_signatures = true; ///< Include auto-generated function signatures in docstrings.
};
static state &global_state() {
static state instance;
return instance;
}
state previous_state;
};
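// Example (sketch, illustrative names): the RAII behaviour in practice;
// signature suppression is active only while `local_options` is alive:
//
//     PYBIND11_MODULE(example, m) {
//         py::options local_options;
//         local_options.disable_function_signatures();
//         m.def("add", [](int a, int b) { return a + b; },
//               "add(a, b)\n\nAdd two numbers."); // docstring kept verbatim
//     } // <-- previous global settings restored here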
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large

File diff suppressed because it is too large

python/src/pybind11/stl.h Normal file

@@ -0,0 +1,386 @@
/*
pybind11/stl.h: Transparent conversion for STL data types
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>
#include <iostream>
#include <list>
#include <deque>
#include <valarray>
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
#ifdef __has_include
// std::optional (but including it in c++14 mode isn't allowed)
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
# include <optional>
# define PYBIND11_HAS_OPTIONAL 1
# endif
// std::experimental::optional (but not allowed in c++11 mode)
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
!__has_include(<optional>))
# include <experimental/optional>
# define PYBIND11_HAS_EXP_OPTIONAL 1
# endif
// std::variant
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
# include <variant>
# define PYBIND11_HAS_VARIANT 1
# endif
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
# include <optional>
# include <variant>
# define PYBIND11_HAS_OPTIONAL 1
# define PYBIND11_HAS_VARIANT 1
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Extracts a const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirectly via forwarded_type(), below.
template <typename T, typename U>
using forwarded_type = conditional_t<
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
/// used for forwarding a container's elements.
template <typename T, typename U>
forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
}
template <typename Type, typename Key> struct set_caster {
using type = Type;
using key_conv = make_caster<Key>;
bool load(handle src, bool convert) {
if (!isinstance<pybind11::set>(src))
return false;
auto s = reinterpret_borrow<pybind11::set>(src);
value.clear();
for (auto entry : s) {
key_conv conv;
if (!conv.load(entry, convert))
return false;
value.insert(cast_op<Key &&>(std::move(conv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(value_))
return handle();
}
return s.release();
}
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
};
template <typename Type, typename Key, typename Value> struct map_caster {
using key_conv = make_caster<Key>;
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<dict>(src))
return false;
auto d = reinterpret_borrow<dict>(src);
value.clear();
for (auto it : d) {
key_conv kconv;
value_conv vconv;
if (!kconv.load(it.first.ptr(), convert) ||
!vconv.load(it.second.ptr(), convert))
return false;
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
if (!key || !value)
return handle();
d[key] = value;
}
return d.release();
}
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
};
template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
reserve_maybe(s, &value);
for (auto it : s) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value.push_back(cast_op<Value &&>(std::move(conv)));
}
return true;
}
private:
template <typename T = Type,
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
void reserve_maybe(sequence, void *) { }
public:
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
};
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> { };
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
using value_conv = make_caster<Value>;
private:
template <bool R = Resizable>
bool require_size(enable_if_t<R, size_t> size) {
if (value.size() != size)
value.resize(size);
return true;
}
template <bool R = Resizable>
bool require_size(enable_if_t<!R, size_t> size) {
return size == Size;
}
public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
for (auto it : l) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value[ctr++] = cast_op<Value &&>(std::move(conv));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
};
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size> { };
template <typename Type> struct type_caster<std::valarray<Type>>
: array_caster<std::valarray<Type>, Type, true> { };
template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key> { };
template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
// This type caster is intended to be used for std::optional and std::experimental::optional
template<typename T> struct optional_caster {
using value_conv = make_caster<typename T::value_type>;
template <typename T_>
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
if (!src)
return none().inc_ref();
policy = return_value_policy_override<typename T::value_type>::policy(policy);
return value_conv::cast(*std::forward<T_>(src), policy, parent);
}
bool load(handle src, bool convert) {
if (!src) {
return false;
} else if (src.is_none()) {
return true; // default-constructed value is already empty
}
value_conv inner_caster;
if (!inner_caster.load(src, convert))
return false;
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
return true;
}
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
};
#if PYBIND11_HAS_OPTIONAL
template<typename T> struct type_caster<std::optional<T>>
: public optional_caster<std::optional<T>> {};
template<> struct type_caster<std::nullopt_t>
: public void_caster<std::nullopt_t> {};
#endif
#if PYBIND11_HAS_EXP_OPTIONAL
template<typename T> struct type_caster<std::experimental::optional<T>>
: public optional_caster<std::experimental::optional<T>> {};
template<> struct type_caster<std::experimental::nullopt_t>
: public void_caster<std::experimental::nullopt_t> {};
#endif
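// Example (sketch, C++17, illustrative names): None loads as an empty optional,
// while anything convertible to the value type loads as a populated one:
//
//     int value_or_zero(std::optional<int> v) { return v.value_or(0); }
//     // m.def("value_or_zero", &value_or_zero, py::arg("v") = py::none());
//     // Python: value_or_zero(None) -> 0, value_or_zero(41) -> 41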
/// Visit a variant and cast any found type to Python
struct variant_caster_visitor {
return_value_policy policy;
handle parent;
using result_type = handle; // required by boost::variant in C++11
template <typename T>
result_type operator()(T &&src) const {
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
}
};
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
/// automatically using argument-dependent lookup. Users can provide specializations for other
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
template <template<typename...> class Variant>
struct visit_helper {
template <typename... Args>
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
return visit(std::forward<Args>(args)...);
}
};
/// Generic variant caster
template <typename Variant> struct variant_caster;
template <template<typename...> class V, typename... Ts>
struct variant_caster<V<Ts...>> {
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
template <typename U, typename... Us>
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
auto caster = make_caster<U>();
if (caster.load(src, convert)) {
value = cast_op<U>(caster);
return true;
}
return load_alternative(src, convert, type_list<Us...>{});
}
bool load_alternative(handle, bool, type_list<>) { return false; }
bool load(handle src, bool convert) {
// Do a first pass without conversions to improve constructor resolution.
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
// slot of the variant. Without two-pass loading `double` would be filled
// because it appears first and a conversion is possible.
if (convert && load_alternative(src, false, type_list<Ts...>{}))
return true;
return load_alternative(src, convert, type_list<Ts...>{});
}
template <typename Variant>
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
std::forward<Variant>(src));
}
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
};
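// Example (sketch, C++17): the two-pass load above in action; the no-convert
// first pass lets the exact `int` alternative win even though `double` is
// declared first and an implicit int -> double conversion exists:
//
//     auto v = py::int_(1).cast<std::variant<double, int>>();
//     assert(std::holds_alternative<int>(v)); // holds int(1), not double(1.0)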
#if PYBIND11_HAS_VARIANT
template <typename... Ts>
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
#endif
NAMESPACE_END(detail)
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
os << (std::string) str(obj);
return os;
}
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER)
#pragma warning(pop)
#endif


@@ -0,0 +1,630 @@
/*
pybind11/stl_bind.h: Binding generators for STL data types
Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
#include "operators.h"
#include <algorithm>
#include <sstream>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/* SFINAE helper class used by 'is_comparable' */
template <typename T> struct container_traits {
template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
template <typename T2> static std::false_type test_comparable(...);
template <typename T2> static std::true_type test_value(typename T2::value_type *);
template <typename T2> static std::false_type test_value(...);
template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
template <typename T2> static std::false_type test_pair(...);
static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
static constexpr const bool is_element = !is_pair && !is_vector;
};
/* Default: is_comparable -> std::false_type */
template <typename T, typename SFINAE = void>
struct is_comparable : std::false_type { };
/* For non-map data structures, check whether operator== can be instantiated */
template <typename T>
struct is_comparable<
T, enable_if_t<container_traits<T>::is_element &&
container_traits<T>::is_comparable>>
: std::true_type { };
/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
template <typename T>
struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
static constexpr const bool value =
is_comparable<typename T::value_type>::value;
};
/* For pairs, recursively check the two data types */
template <typename T>
struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
static constexpr const bool value =
is_comparable<typename T::first_type>::value &&
is_comparable<typename T::second_type>::value;
};
/* Fallback functions */
template <typename, typename, typename... Args> void vector_if_copy_constructible(const Args &...) { }
template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
template<typename Vector, typename Class_>
void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
cl.def(init<const Vector &>(), "Copy constructor");
}
template<typename Vector, typename Class_>
void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
cl.def(self == self);
cl.def(self != self);
cl.def("count",
[](const Vector &v, const T &x) {
return std::count(v.begin(), v.end(), x);
},
arg("x"),
"Return the number of times ``x`` appears in the list"
);
cl.def("remove", [](Vector &v, const T &x) {
auto p = std::find(v.begin(), v.end(), x);
if (p != v.end())
v.erase(p);
else
throw value_error();
},
arg("x"),
"Remove the first item from the list whose value is x. "
"It is an error if there is no such item."
);
cl.def("__contains__",
[](const Vector &v, const T &x) {
return std::find(v.begin(), v.end(), x) != v.end();
},
arg("x"),
"Return true the container contains ``x``"
);
}
// Vector modifiers -- requires a copyable vector_type:
// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
// silly to allow deletion but not insertion, so include them here too.)
template <typename Vector, typename Class_>
void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using DiffType = typename Vector::difference_type;
cl.def("append",
[](Vector &v, const T &value) { v.push_back(value); },
arg("x"),
"Add an item to the end of the list");
cl.def(init([](iterable it) {
auto v = std::unique_ptr<Vector>(new Vector());
v->reserve(len_hint(it));
for (handle h : it)
v->push_back(h.cast<T>());
return v.release();
}));
cl.def("extend",
[](Vector &v, const Vector &src) {
v.insert(v.end(), src.begin(), src.end());
},
arg("L"),
"Extend the list by appending all the items in the given list"
);
cl.def("extend",
[](Vector &v, iterable it) {
const size_t old_size = v.size();
v.reserve(old_size + len_hint(it));
try {
for (handle h : it) {
v.push_back(h.cast<T>());
}
} catch (const cast_error &) {
v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
try {
v.shrink_to_fit();
} catch (const std::exception &) {
// Do nothing
}
throw;
}
},
arg("L"),
"Extend the list by appending all the items in the given list"
);
cl.def("insert",
[](Vector &v, SizeType i, const T &x) {
if (i > v.size())
throw index_error();
v.insert(v.begin() + (DiffType) i, x);
},
arg("i") , arg("x"),
"Insert an item at a given position."
);
cl.def("pop",
[](Vector &v) {
if (v.empty())
throw index_error();
T t = v.back();
v.pop_back();
return t;
},
"Remove and return the last item"
);
cl.def("pop",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
T t = v[i];
v.erase(v.begin() + (DiffType) i);
return t;
},
arg("i"),
"Remove and return the item at index ``i``"
);
cl.def("__setitem__",
[](Vector &v, SizeType i, const T &t) {
if (i >= v.size())
throw index_error();
v[i] = t;
}
);
/// Slicing protocol
cl.def("__getitem__",
[](const Vector &v, slice slice) -> Vector * {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
Vector *seq = new Vector();
seq->reserve((size_t) slicelength);
for (size_t i=0; i<slicelength; ++i) {
seq->push_back(v[start]);
start += step;
}
return seq;
},
arg("s"),
"Retrieve list elements using a slice object"
);
cl.def("__setitem__",
[](Vector &v, slice slice, const Vector &value) {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
if (slicelength != value.size())
throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
for (size_t i=0; i<slicelength; ++i) {
v[start] = value[i];
start += step;
}
},
"Assign list elements using a slice object"
);
cl.def("__delitem__",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
v.erase(v.begin() + DiffType(i));
},
"Delete the list elements at index ``i``"
);
cl.def("__delitem__",
[](Vector &v, slice slice) {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
if (step == 1 && false) {
v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
} else {
for (size_t i = 0; i < slicelength; ++i) {
v.erase(v.begin() + DiffType(start));
start += step - 1;
}
}
},
"Delete list elements using a slice object"
);
}
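// A hedged sketch of the Python-side behaviour the modifiers above provide,
// assuming a hypothetical binding `VectorInt` created with bind_vector below:
//
//     v = VectorInt([1, 2, 3])      # init from any Python iterable
//     v.append(4)                   # [1, 2, 3, 4]
//     v.extend([5, 6])              # [1, 2, 3, 4, 5, 6]
//     v.insert(0, 0)                # [0, 1, 2, 3, 4, 5, 6]
//     v.pop()                       # returns 6
//     v[0:4:2] = VectorInt([9, 9])  # slice assignment (sizes must match)
//     del v[1:3]                    # slice deletion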
// If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
// we have to access by copying; otherwise we return by reference.
template <typename Vector> using vector_needs_copy = negation<
std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
// The usual case: access and iterate by reference
template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using ItType = typename Vector::iterator;
cl.def("__getitem__",
[](Vector &v, SizeType i) -> T & {
if (i >= v.size())
throw index_error();
return v[i];
},
return_value_policy::reference_internal // ref + keepalive
);
cl.def("__iter__",
[](Vector &v) {
return make_iterator<
return_value_policy::reference_internal, ItType, ItType, T&>(
v.begin(), v.end());
},
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
}
// The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using ItType = typename Vector::iterator;
cl.def("__getitem__",
[](const Vector &v, SizeType i) -> T {
if (i >= v.size())
throw index_error();
return v[i];
}
);
cl.def("__iter__",
[](Vector &v) {
return make_iterator<
return_value_policy::copy, ItType, ItType, T>(
v.begin(), v.end());
},
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
}
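// Illustrative only (not part of the original header): std::vector<bool> packs its
// elements, so operator[] returns a proxy object rather than a bool&, which selects
// the returned-by-copy overload above; ordinary vectors use the reference overload:
//
//     static_assert(vector_needs_copy<std::vector<bool>>::value, "proxy access");
//     static_assert(!vector_needs_copy<std::vector<int>>::value, "real references");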
template <typename Vector, typename Class_> auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
-> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
using size_type = typename Vector::size_type;
cl.def("__repr__",
[name](Vector &v) {
std::ostringstream s;
s << name << '[';
for (size_type i=0; i < v.size(); ++i) {
s << v[i];
if (i != v.size() - 1)
s << ", ";
}
s << ']';
return s.str();
},
"Return the canonical string representation of this list."
);
}
// Provide the buffer interface for vectors if we have data() and we have a format for it
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
template <typename Vector, typename = void>
struct vector_has_data_and_format : std::false_type {};
template <typename Vector>
struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
// Add the buffer interface to a vector
template <typename Vector, typename Class_, typename... Args>
enable_if_t<detail::any_of<std::is_same<Args, buffer_protocol>...>::value>
vector_buffer(Class_& cl) {
using T = typename Vector::value_type;
static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
// numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
format_descriptor<T>::format();
cl.def_buffer([](Vector& v) -> buffer_info {
return buffer_info(v.data(), static_cast<ssize_t>(sizeof(T)), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
});
cl.def(init([](buffer buf) {
auto info = buf.request();
if (info.ndim != 1 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
throw type_error("Only valid 1D buffers can be copied to a vector");
if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
auto vec = std::unique_ptr<Vector>(new Vector());
vec->reserve((size_t) info.shape[0]);
T *p = static_cast<T*>(info.ptr);
ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
T *end = p + info.shape[0] * step;
for (; p != end; p += step)
vec->push_back(*p);
return vec.release();
}));
return;
}
template <typename Vector, typename Class_, typename... Args>
enable_if_t<!detail::any_of<std::is_same<Args, buffer_protocol>...>::value> vector_buffer(Class_&) {}
NAMESPACE_END(detail)
//
// std::vector
//
template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
using Class_ = class_<Vector, holder_type>;
// If the value_type is unregistered (e.g. a converting type) or is itself registered
// module-local then make the vector binding module-local as well:
using vtype = typename Vector::value_type;
auto vtype_info = detail::get_type_info(typeid(vtype));
bool local = !vtype_info || vtype_info->module_local;
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
// Declare the buffer interface if a buffer_protocol() is passed in
detail::vector_buffer<Vector, Class_, Args...>(cl);
cl.def(init<>());
// Register copy constructor (if possible)
detail::vector_if_copy_constructible<Vector, Class_>(cl);
// Register comparison-related operators and functions (if possible)
detail::vector_if_equal_operator<Vector, Class_>(cl);
// Register stream insertion operator (if possible)
detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
// Modifiers require copyable vector value type
detail::vector_modifiers<Vector, Class_>(cl);
// Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
detail::vector_accessor<Vector, Class_>(cl);
cl.def("__bool__",
[](const Vector &v) -> bool {
return !v.empty();
},
"Check whether the list is nonempty"
);
cl.def("__len__", &Vector::size);
#if 0
// C++ style functions deprecated, leaving it here as an example
cl.def(init<size_type>());
cl.def("resize",
(void (Vector::*) (size_type count)) & Vector::resize,
"changes the number of elements stored");
cl.def("erase",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
v.erase(v.begin() + i);
}, "erases element at index ``i``");
cl.def("empty", &Vector::empty, "checks whether the container is empty");
cl.def("size", &Vector::size, "returns the number of elements");
cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
cl.def("pop_back", &Vector::pop_back, "removes the last element");
cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
cl.def("reserve", &Vector::reserve, "reserves storage");
cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
cl.def("clear", &Vector::clear, "clears the contents");
cl.def("swap", &Vector::swap, "swaps the contents");
cl.def("front", [](Vector &v) {
if (v.size()) return v.front();
else throw index_error();
}, "access the first element");
cl.def("back", [](Vector &v) {
if (v.size()) return v.back();
else throw index_error();
}, "access the last element ");
#endif
return cl;
}
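// A minimal usage sketch (an assumption for illustration, not part of this header):
//
//     PYBIND11_MODULE(example, m) {
//         pybind11::bind_vector<std::vector<int>>(m, "VectorInt");
//         // POD element types may additionally opt into the buffer interface:
//         pybind11::bind_vector<std::vector<float>>(m, "VectorFloat",
//                                                   pybind11::buffer_protocol());
//     }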
//
// std::map, std::unordered_map
//
NAMESPACE_BEGIN(detail)
/* Fallback functions */
template <typename, typename, typename... Args> void map_if_insertion_operator(const Args &...) { }
template <typename, typename, typename... Args> void map_assignment(const Args &...) { }
// Map assignment when copy-assignable: just copy the value
template <typename Map, typename Class_>
void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
cl.def("__setitem__",
[](Map &m, const KeyType &k, const MappedType &v) {
auto it = m.find(k);
if (it != m.end()) it->second = v;
else m.emplace(k, v);
}
);
}
// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
template<typename Map, typename Class_>
void map_assignment(enable_if_t<
!std::is_copy_assignable<typename Map::mapped_type>::value &&
is_copy_constructible<typename Map::mapped_type>::value,
Class_> &cl) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
cl.def("__setitem__",
[](Map &m, const KeyType &k, const MappedType &v) {
// We can't use m[k] = v; because the value type might not be default constructible
auto r = m.emplace(k, v);
if (!r.second) {
// value type is not copy assignable so the only way to insert it is to erase it first...
m.erase(r.first);
m.emplace(k, v);
}
}
);
}
template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
-> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {
cl.def("__repr__",
[name](Map &m) {
std::ostringstream s;
s << name << '{';
bool f = false;
for (auto const &kv : m) {
if (f)
s << ", ";
s << kv.first << ": " << kv.second;
f = true;
}
s << '}';
return s.str();
},
"Return the canonical string representation of this map."
);
}
NAMESPACE_END(detail)
template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
using Class_ = class_<Map, holder_type>;
// If either type is a non-module-local bound type then make the map binding non-local as well;
// otherwise (e.g. both types are either module-local or converting) the map will be
// module-local.
auto tinfo = detail::get_type_info(typeid(MappedType));
bool local = !tinfo || tinfo->module_local;
if (local) {
tinfo = detail::get_type_info(typeid(KeyType));
local = !tinfo || tinfo->module_local;
}
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
cl.def(init<>());
// Register stream insertion operator (if possible)
detail::map_if_insertion_operator<Map, Class_>(cl, name);
cl.def("__bool__",
[](const Map &m) -> bool { return !m.empty(); },
"Check whether the map is nonempty"
);
cl.def("__iter__",
[](Map &m) { return make_key_iterator(m.begin(), m.end()); },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
cl.def("items",
[](Map &m) { return make_iterator(m.begin(), m.end()); },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
cl.def("__getitem__",
[](Map &m, const KeyType &k) -> MappedType & {
auto it = m.find(k);
if (it == m.end())
throw key_error();
return it->second;
},
return_value_policy::reference_internal // ref + keepalive
);
cl.def("__contains__",
[](Map &m, const KeyType &k) -> bool {
auto it = m.find(k);
if (it == m.end())
return false;
return true;
}
);
// Assignment provided only if the type is copyable
detail::map_assignment<Map, Class_>(cl);
cl.def("__delitem__",
[](Map &m, const KeyType &k) {
auto it = m.find(k);
if (it == m.end())
throw key_error();
m.erase(it);
}
);
cl.def("__len__", &Map::size);
return cl;
}
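// A minimal usage sketch (an assumption for illustration, not part of this header):
//
//     PYBIND11_MODULE(example, m) {
//         pybind11::bind_map<std::map<std::string, int>>(m, "MapStringInt");
//     }
//
// Python side: d = example.MapStringInt(); d["x"] = 1; "x" in d; list(d.items()); del d["x"]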
NAMESPACE_END(PYBIND11_NAMESPACE)

1
python/triton/_C/include Symbolic link
View File

@@ -0,0 +1 @@
../../../include/

12
python/triton/__init__.py Normal file
View File

@@ -0,0 +1,12 @@
from .kernel import *
from .function import *
from .utils import *
import triton.ops
# clean-up libtriton resources
import atexit
import triton._C.libtriton as libtriton
@atexit.register
def cleanup():
libtriton.cleanup()

28
python/triton/frameworks.py Normal file
View File

@@ -0,0 +1,28 @@
import sys
import os
import triton._C.libtriton as libtriton
torch = None
tensorflow = None
def _import_torch():
global torch
if torch is None:
import torch
def _import_tensorflow():
global tensorflow
if tensorflow is None:
import tensorflow
def has_tensorflow():
result = 'tensorflow' in sys.modules
if result:
_import_tensorflow()
return result
def has_torch():
result = 'torch' in sys.modules
if result:
_import_torch()
return result
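# A hedged usage sketch: detection keys off sys.modules, so triton binds to a
# framework only if the host program has imported it first, e.g.
#
#   import torch                      # puts 'torch' in sys.modules
#   import triton.frameworks as fw
#   assert fw.has_torch() and not fw.has_tensorflow()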

127
python/triton/function.py Normal file
View File

@@ -0,0 +1,127 @@
import triton.frameworks as fw
import triton.utils as utils
class OpContext(object):
def __init__(self):
self.to_save = []
def save_for_backward(self, *tensors):
self.to_save = [x.to_tensor() if isinstance(x, utils.tf_empty_proxy) else x
for x in tensors]
@property
def saved_tensors(self):
return self.to_save
class function_meta(type):
def __init__(cls, name, bases, attrs):
cls.registered = False
return super(function_meta, cls).__init__(name, bases, attrs)
ctx_registry = utils.id_dict()
class function(metaclass = function_meta):
@staticmethod
def forward(ctx, *args, **kwargs):
raise NotImplementedError
@staticmethod
def backward(ctx, grad_output):
raise NotImplementedError
@classmethod
def apply_torch(cls, *args, **kwargs):
class TorchFunction(fw.torch.autograd.Function):
@staticmethod
def forward(ctx, *targs):
y = cls.forward(ctx, *targs, **cls.torch_kwargs)
ctx_registry[y] = ctx
return y
@staticmethod
def backward(ctx, grad_output):
return cls.backward(ctx, grad_output)
cls.torch_kwargs = kwargs
return TorchFunction.apply(*args)
torch_kwargs = 0
@classmethod
def extract_tf_tensors(cls, lst, err):
ret = []
for x in lst:
if x is None:
ret += [None]
elif isinstance(x, fw.tensorflow.Tensor):
ret += [x]
elif isinstance(x, utils.tf_empty_proxy):
if x.tensor is None:
raise ValueError('Empty tensor never filled during ' + err)
else:
ret += [x.tensor]
else:
raise ValueError('Unsupported return type', type(x))
return ret
@classmethod
def map_in_to_args(cls, op, args):
ret = dict()
for i, ix in enumerate(op.inputs):
for j, jx in enumerate(args):
if ix is jx:
ret[j] = i
return ret
@classmethod
def map_res_to_out(cls, op, result):
ret = []
for i, ix in enumerate(result):
for j, jx in enumerate(op.outputs):
if ix is jx:
ret.append(j)
return ret
@classmethod
def apply_tensorflow(cls, *args, **kwargs):
ctx = OpContext()
# run forward pass
result = cls.forward(ctx, *args, **kwargs)
result = result if isinstance(result, tuple) else (result, )
result = function.extract_tf_tensors(result, 'forward')
# Register backward pass
op = result[0].op
ctx_registry[op] = ctx
if not cls.registered:
remap_in = cls.map_in_to_args(op, args)
remap_out = cls.map_res_to_out(op, result)
@fw.tensorflow.RegisterGradient(op.op_def.name)
def gradient(op, *dy):
# Remap gradient inputs in the right order
dy = [dy[i] for i in remap_out]
dy = dy if len(dy) > 1 else dy[0]
# Execute gradient function
grad = cls.backward(ctx_registry[op], dy)
grad = function.extract_tf_tensors(grad, 'backward')
# Remap gradient in the right order
ret = [None] * len(op.inputs)
for i in range(len(grad)):
if i in remap_in:
ret[remap_in[i]] = grad[i]
# Return
return ret
cls.registered = True
# Return tensor
return result[0] if len(result)==1 else result
@classmethod
def apply(cls, *args, **kwargs):
if fw.has_tensorflow():
return cls.apply_tensorflow(*args, **kwargs)
elif fw.has_torch():
return cls.apply_torch(*args, **kwargs)
else:
assert False
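# A hedged sketch (not part of this file) of defining a custom differentiable op
# on top of `function`; a trivial elementwise scale stands in for a Triton kernel:
#
#   class _scale(triton.function):
#       @staticmethod
#       def forward(ctx, x, alpha):
#           ctx.alpha = alpha
#           return x * alpha
#       @staticmethod
#       def backward(ctx, dy):
#           return dy * ctx.alpha, None   # one gradient per forward input
#   scale = _scale.apply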

298
python/triton/kernel.py Normal file
View File

@@ -0,0 +1,298 @@
# import for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
def _make_framework_src(src, out, tmp, grid):
if fw.has_tensorflow():
return libtriton.make_tensorflow_src(src, out, tmp, grid)
elif fw.has_torch():
return libtriton.make_torch_src(src, out, tmp, grid)
else:
assert False
def _make_cache_path(src):
sha = hashlib.sha1(src.encode())
hexhash = sha.hexdigest()
home = os.path.expanduser('~')
cacheroot = os.path.join(home, '.triton', 'cache')
cachepath = os.path.join(cacheroot, str(hexhash))
if not os.path.exists(cachepath):
os.makedirs(cachepath)
return cachepath
def _write_bindings(src, root):
if fw.has_tensorflow():
name = 'tensorflow'
elif fw.has_torch():
name = 'torch'
else:
assert False
cpp = os.path.join(root, '{name}.cpp'.format(name=name))
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
recompile = False
# recompile if .so does not exist
if not os.path.exists(cpp) or not os.path.exists(so):
recompile = True
# recompile if cpp was modified after .so
elif max(cpp, so, key=os.path.getmtime) == cpp:
recompile = True
# write cpp file
if recompile:
with open(cpp, 'w+') as handle:
handle.writelines(src)
# return path of cpp file
return (cpp, so)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(src, path):
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
ccdir = os.path.realpath(ccdir)
# include directories
triton_include_dirs = [os.path.join(ccdir, 'include')]
include_dirs = triton_include_dirs
# library directories
triton_library_dirs = [ccdir]
library_dirs = triton_library_dirs
# libraries
libraries = ['triton']
# add framework
extra_compile_args = []
if fw.has_tensorflow():
library_dirs += [fw.tensorflow.sysconfig.get_lib()]
include_dirs += [fw.tensorflow.sysconfig.get_include()]
include_dirs += ['/usr/local/cuda/include/']
libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
name = 'tensorflow'
elif fw.has_torch():
prefix = os.path.dirname(fw.torch.__file__)
library_dirs += [os.path.join(prefix, 'lib')]
include_dirs += ['/usr/local/cuda/include/',
os.path.join(prefix, 'lib', 'include'),
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
os.path.join(prefix, 'include'),
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
libraries += ['torch']
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
name = 'torch'
else:
assert False
# extra arguments
extra_link_args = []
# dependencies
depends = [os.path.realpath(libtriton.__file__)]
# create extension module
ext = setuptools.Extension(
name = name,
language = 'c++',
sources = [src],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args + ['-g0'],
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries,
depends = depends
)
# build extension module
args = ['build_ext']
tmp = tempfile.mkdtemp()
args.append('--build-temp=' + tmp)
args.append('--build-lib=' + path)
args.append('-q')
args = dict(
name = name,
ext_modules = [ext],
script_args = args,
)
with quiet():
setuptools.setup(**args)
shutil.rmtree(tmp)
def _cvt_to_def_str(obj):
# bool
if isinstance(obj, bool):
return str(int(obj))
# tensorflow type
if fw.has_tensorflow():
if isinstance(obj, fw.tensorflow.DType):
return {fw.tensorflow.int8: 'char',
fw.tensorflow.int16: 'short',
fw.tensorflow.int32: 'int',
fw.tensorflow.int64: 'long',
fw.tensorflow.float16: 'half',
fw.tensorflow.float32: 'float',
fw.tensorflow.float64: 'double'}[obj]
# torch type
elif fw.has_torch():
if isinstance(obj, fw.torch.dtype):
return {fw.torch.int8: 'char',
fw.torch.int16: 'short',
fw.torch.int32: 'int',
fw.torch.int64: 'long',
fw.torch.float16: 'half',
fw.torch.float32: 'float',
fw.torch.float64: 'double'}[obj]
else:
assert False
# default
return str(obj)
def _make_framework_op(src, outputs, tmp, options):
src, name = _make_framework_src(src, outputs, tmp, options)
cache_path = _make_cache_path(src)
cpp, so = _write_bindings(src, cache_path)
_build(cpp, cache_path)
if fw.has_tensorflow():
return fw.tensorflow.load_op_library(so).__dict__[name]
elif fw.has_torch():
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
else:
assert False
def _make_grid(grid, args):
scalars = [x for x in args if isinstance(x, triton.utils.scalar)]
# wrap the user-provided grid function so scalar arguments can be
# evaluated before their values are actually known
def grid_fn(opt):
for x in scalars:
x.set_assume_initialized()
result = grid(opt)
for x in scalars:
x.unset_assume_initialized()
return result
return grid_fn
bench_registry = triton.utils.id_dict()
class kernel:
def __init__(self, src, outputs, tmp=[]):
self.fw_id = dict()
self.fw_grids = dict()
self.fw_op = None
self.src = src
self.outputs = outputs
self.tmp = tmp
self.cst = dict()
def set_constant(self, name, value):
self.cst[name] = value
def __call__(self, *args, **kwargs):
########################
# keyword arguments
########################
num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8]
defines = kwargs['defines'] if 'defines' in kwargs else dict()
bench = kwargs['bench'] if 'bench' in kwargs else 0
if 'grid' not in kwargs:
raise RuntimeError('Must provide grid for kernel launch')
grid = kwargs['grid']
#########################
# cache
########################
# create a new framework op when defines are different
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
if key not in self.fw_id.keys():
# code generation options
macros = []
for k, v in defines.items():
cvt = lambda x: _cvt_to_def_str(x)
if(isinstance(v, list)):
values = list(map(cvt, v))
else:
values = [cvt(v)]
macros.append((k, values))
opt = libtriton.options_space()
opt.defines = macros
opt.num_warps = [2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
# register function
libtriton.register_fn(op_id, self.src, opt)
for name, value in self.cst.items():
libtriton.register_cst(op_id, name, value)
if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
########################
# initialize
########################
op_id = self.fw_id[key]
libtriton.register_grid(op_id, grid)
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
#########################
# call framework function
#########################
if fw.has_tensorflow():
empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
if len(empty) != len(self.outputs):
raise ValueError('Number of empty arguments does not match the number of outputs provided')
# operands
operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
# output data types
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
for i, x in enumerate(args):
if isinstance(x, triton.utils.tf_empty_proxy):
kwargs['T' + str(i)] = x.dtype
# launch
ret = self.fw_op(*operands, **kwargs)
ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
op_def = ret[0].op.op_def
# fill empty tensors with corresponding values
for j, y in enumerate(op_def.output_arg):
found = False
for i, x in enumerate(op_def.input_arg):
if y.name + '_shape' == x.name:
args[i].tensor = ret[j]
found = True
assert found
# store timing information
if bench > 0:
for y in ret:
bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
############################
# call torch function
############################
elif fw.has_torch():
# kernels assume contiguous memory; make tensor arguments contiguous
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args]
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
else:
assert False
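# A hedged usage sketch (not part of this file): a kernel wraps Triton-C source,
# names its output arguments, and is launched with `grid` and `defines` keywords.
# The Triton-C syntax below mirrors the ops in triton/ops and is an assumption:
#
#   src = """
#   void copy(TYPE *X, TYPE *Y, int N) {
#       int idx[TM] = get_program_id(0) * TM + 0 ... TM;
#       TYPE *px[TM] = X + idx;
#       TYPE *py[TM] = Y + idx;
#       *?(idx < N)py = *?(idx < N)px;
#   }"""
#   copy = kernel(src, ['Y'])
#   copy(x, y, N, grid=lambda opt: [triton.cdiv(N, opt.d('TM'))],
#        defines={'TM': 128, 'TYPE': x.dtype})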

2
python/triton/ops/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm

129
python/triton/ops/batchnorm.py Normal file
View File

@@ -0,0 +1,129 @@
import triton
import math
class _batchnorm(triton.function):
fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
float *X, float *G, float *B,
int N, float eps) {
// pointers
int c = get_program_id(1);
int rm[TM] = 0 ... TM;
float *px[TM] = X + rm + c*N;
float* py[TM] = Y + rm + c*N;
// compute mean
float accm[TM] = 0;
for(int i = 0; i < N; i = i + TM)
accm = accm + *(px + i);
float mean = (float)accm[+] / N;
*(M + c) = mean;
// compute variance
float accv[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
x = x - mean;
accv = accv + x*x;
}
float var = (float)accv[+] / N;
*(V + c) = var;
// Normalize batch
float gamma = *(G + c);
float beta = *(B + c);
float rstdg = 1 / sqrtf(var + eps) * gamma;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float y[TM] = (x - mean)*rstdg + beta;
*(py + i) = y;
}
}
"""
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
float *DY, float *X, float *G,
float *M, float *V,
int N, float epsilon) {
// pointers
int c = get_program_id(1);
int rx[TM] = 0 ... TM;
int offset = c*N;
float* px[TM] = X + rx + offset;
float* pdy[TM] = DY + rx + offset;
float* pdx[TM] = DX + rx + offset;
// fetch statistics
float gamma = *(G + c);
float mean = *(M + c);
float var = *(V + c);
float rstd = 1 / sqrtf(var + epsilon);
// compute dgamma and dbeta
float acc_dg[TM] = 0;
float acc_db[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
acc_dg += dy*(x - mean)*rstd;
acc_db += dy;
}
float dg = acc_dg[+];
float db = acc_db[+];
*(DG + c) = dg;
*(DB + c) = db;
// compute dx
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
float xhat[TM] = (x - mean) * rstd;
float xtmp[TM] = (xhat * dg + db) / N;
float dx[TM] = (dy - xtmp) * rstd * gamma;
*(pdx + i) = dx;
}
}
"""
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
@staticmethod
def forward(ctx, x, gamma, beta, eps):
shape = triton.shape(x)
dtype = x.dtype
# allocate outputs
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
y = triton.empty(shape, dtype=dtype)
mean = triton.empty([C], dtype=dtype)
var = triton.empty([C], dtype=dtype)
# execute kernels
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
grid = lambda opt: [1, C],
defines = {'TM': 128})
# save
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
# retrieve info
x, gamma, beta, mean, var = ctx.saved_tensors
eps = ctx.eps
# allocate result
dx = triton.empty(triton.shape(x), dtype=x.dtype)
dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
# execute
C, H, W, B = triton.shape(x)
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
H*W*B, eps,
grid = lambda opt: [1, C],
defines = {'TM': 128})
return dx, dgamma, dbeta, None
batchnorm = _batchnorm.apply
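# A hedged usage sketch: note the (C, H, W, B) layout implied by forward above,
# with statistics and normalization computed per channel over N = H*W*B elements:
#
#   import torch, triton
#   x = torch.randn(32, 16, 16, 8, device='cuda')   # (C, H, W, B)
#   gamma = torch.ones(32, device='cuda')
#   beta = torch.zeros(32, device='cuda')
#   y = triton.ops.batchnorm(x, gamma, beta, 1e-5)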

650
python/triton/ops/einsum.py Normal file
View File

@@ -0,0 +1,650 @@
import numpy as np
import torch
from math import ceil, log2
from enum import IntEnum
import triton
from functools import reduce
from operator import mul
from sympy.parsing.sympy_parser import parse_expr
import sympy as sp
from collections import OrderedDict
from collections import namedtuple
import re
from sympy.printing.ccode import C89CodePrinter
class _einsum(triton.function):
#############################
## Triton-C code generation
#############################
def print_cc(expr, axes_0, axes_1, axes_2):
class TritonCodePrinter(C89CodePrinter):
def __init__(self, axes_0, axes_1, axes_2):
super(TritonCodePrinter, self).__init__()
self.axes_0 = axes_0
self.axes_1 = axes_1
self.axes_2 = axes_2
def _print_Symbol(self, expr):
name = super(C89CodePrinter, self)._print_Symbol(expr)
if expr in self.axes_0:
return f'r{name}[:, newaxis, newaxis]'
if expr in self.axes_1:
return f'r{name}[newaxis, :, newaxis]'
if expr in self.axes_2:
return f'r{name}[newaxis, newaxis, :]'
return name
def _print_Indexed(self, expr):
assert len(expr.indices) == 1
return "*(%s + %s)" % (self._print(expr.base.label),
self._print(expr.indices[0]))
return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)
def unpack_cc(tile, axes, prefix, remat):
ret = ''
axes = list(map(str, axes))
for i, d in enumerate(reversed(axes)):
if i == len(axes) - 1:
break
currs = ''.join(axes[: len(axes) - i])
nexts = ''.join(axes[: len(axes) - (i + 1)])
ty = '' if remat else 'int '
sz = '' if remat else f'[{tile}]'
ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
return ret
def strides_cc(name, expr):
ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
ret = dict(zip(expr, ret))
return ret
def make_kernel(name,
expr_a, expr_b, expr_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted):
use_lut_a = True
use_lut_b = True
src = ""
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""
src += f"""
__global__ void {name}(
TYPE * A __noalias __readonly __aligned(16)
, TYPE * B __noalias __readonly __aligned(16)
, TYPE * C
, int * locks
, float alpha
, int matmul_m, int matmul_n, int matmul_k __multipleof(16)
, int div_m
"""
for dim in [axes_m, axes_n, axes_k, axes_b]:
for d in dim:
src += f", int dim_{d}"
src += "\n "
for dim, name, mult in zip([expr_a, expr_b, expr_c],
['a', 'b', 'c'],
[multipleof_a, multipleof_b, multipleof_c]):
for d in range(len(dim) - 1):
attr = f'__multipleof({mult})'
src += f", int stride_{name}_{d} {attr}"
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
for ptr in subscripted:
src += f", int* {ptr}"
src += """) {
// re-order outer program ids
int grid_m = (matmul_m + TM - 1) / TM;
int grid_n = (matmul_n + TN - 1) / TN;
int pid_mn = get_program_id(0) / div_m;
int pid_n = pid_mn % grid_n;
int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
// get batch program id
int pid_b = get_program_id(1);
#if TZ == 1
int off_k = 0;
#else
// get reduction sub-group program id
int pid_z = get_program_id(2);
int grid_z = get_num_programs(2);
int div_z = matmul_k / TZ;
int rem_z = matmul_k % TZ;
int off_k = pid_z * div_z;
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
// create ranges
"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
currs = ''.join(map(str,axes))
if axes:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', False)
src += """
// initialize pointers to A
int offa[TM, TK, TB] = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pa[TM, TK, TB] = A + offa;"""
if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to A look-up table
int offadelta[TK] = off_k + 0 ... TK;
int {spec} *padelta[TK] = {cast}AD + offadelta;
int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
src += """
// initialize pointers to B
int offb[TK, TN, TB] = """
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pb[TK, TN, TB] = B + offb;"""
if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to B look-up table
int offbdelta[TK] = off_k + 0 ... TK;
int *pbdelta[TK] = BD + offbdelta;"""
src += f"""
// prefetch
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = {rk} < matmul_k + off_k;
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;"""
if not use_lut_a or not use_lut_b:
src += f"""
{rk} += TK;
"""
src += _einsum.unpack_cc(tile, axes_k, 'r', True)
if use_lut_a:
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
else:
src += """
offa = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += """;
TYPE *pa[TM, TK, TB] = A + offa;"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += stride_b_inner;"""
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
checkk = k > TK;
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;
}}
TYPE c[TM, TN, TB] = acc;
// re-materialize ranges
"""
for axes, tile, off in zip([axes_m, axes_n, axes_b],
['TM', 'TN', 'TB'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
currs = ''.join(map(str,axes))
if axes:
src += f" r{currs} = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', True)
src += """
// initialize pointers to C
int offc[TM, TN, TB] = """
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pc[TM, TN, TB] = C + offc;
// bounds-checking
checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
// write back
#if TZ == 1
*?(checkc)pc = c;
#else
int *plock = locks + pid_mn + pid_b * get_num_programs(0);
int *pcount = plock + 1024*1024;
// spin
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc)pc = c;
else
*?(checkc)pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % (grid_z));
atomic_xchg(plock, 0);
#endif
}
"""
#print(src)
ret = triton.kernel(src, ['C'])
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('BD', delta_b)
return ret
############################
## Look-up Table
############################
class LUT_MODE(IntEnum):
SCALAR = 1
CONSTANT = 2
DRAM = 3
def lut_mode(delta):
if delta.size == 0 or np.min(delta) == np.max(delta):
return _einsum.LUT_MODE.SCALAR
#if delta.size < 4096:
# return _einsum.LUT_MODE.CONSTANT
return _einsum.LUT_MODE.DRAM
def symbolic_delta(symbols, axes):
rank = len(symbols)
strides = [sp.symbols(f'stride{d}') for d in range(rank)]
nexts = {s: sp.symbols(f'next{s}') for s in axes}
delta = 0
for i in range(rank):
delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
return delta
def unpack_offset(k, axes, dims):
ret = dict()
for d in reversed(axes):
ret[d] = k % dims[d]
k = k // dims[d]
return ret
def make_delta(axes, step, stride, dims, symbols, arrays):
# symbolic pointer increments
delta = _einsum.symbolic_delta(symbols, axes)
args = [f'stride{d}' for d in range(len(stride))]
args += [f'{sk}' for sk in axes]
args += [f'next{sk}' for sk in axes]
args += [f'{sk}' for sk, _ in arrays]
fn = sp.lambdify(args, delta, 'numpy')
# inner axes values
inner = [dims[d] for d in axes]
k = np.arange(np.prod(inner), dtype=np.int32)
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(k + step, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
args += [x for _, x in arrays]
delta = fn(*args)
return delta, _einsum.lut_mode(delta[:-step])
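# Worked example (illustrative): for an einsum 'k,k->' with dims[k] = 8, TK = 4
# and unit stride, off = [0..7] and nextoff = off + 4, so every entry of delta is
# 4 * stride = 4; min == max, hence LUT_MODE.SCALAR and the kernel receives a
# single `stride_a_inner` scalar instead of a DRAM look-up table.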
############################
## Einsum parsing
############################
def uniq(seq):
seen = set()
seen_add = seen.add
return [x for x in seq if not (x in seen or seen_add(x))]
def parse_axes(expr_a, expr_b, expr_c, subscripted):
is_index = lambda x: isinstance(x, sp.Indexed) or str(x) in subscripted
sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
sym_c = [x for s in expr_c for x in s.free_symbols]
batch = [d for d in sym_a if d in sym_b and d in sym_c]
outer = [d for d in sym_a if d not in sym_b and d in sym_c]
inner = [d for d in sym_a if d in sym_b and d not in sym_c]
illegal = [d for d in sym_a if d not in sym_b and d not in sym_c]
if illegal:
raise ValueError(f"einsum labels {illegal} ({expr_a}) "\
f"not present in {expr_b} or {expr_c}")
return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner)
def replace_subscript(expr, arrays):
# replace array indexing by Indexed()
indexed = re.findall(r'([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
for x in indexed:
arrays.append(x[0])
expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
return expr
def parse_expr(expr, arrays):
# extract symbols
sym = []
i = 0
while i < len(expr):
d = expr[i]
if d == '(':
size = expr[i:].find(')')
d = expr[i : i + size + 1]
d = _einsum.replace_subscript(d, arrays)
sym.append(parse_expr(d))
i += size + 1
else:
sym.append(parse_expr(d))
i += 1
return sym
############################
## Preprocessing
############################
@staticmethod
def pad(tensor, pad):
pad = pad + [0] * (2*len(tensor.shape) - len(pad))
begin = [ x if x > 0 else None for x in pad[-1::-2]]
end = [-x if x > 0 else None for x in pad[-2::-2]]
slices = [slice(b, e) for b, e in zip(begin, end)]
tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0)
tensor = tensor[slices]
return tensor
############################
## Compilation
############################
class instance:
locks = None
kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
subscripted = []
sym_a = _einsum.parse_expr(expr_a, subscripted)
sym_b = _einsum.parse_expr(expr_b, subscripted)
sym_c = _einsum.parse_expr(expr_c, subscripted)
# parse axes
axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
_, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
axes = axes_b + axes_m + axes_n + axes_k
# check dimensions
dims_a = dict(zip(sym_a, shape_a))
dims_b = dict(zip(sym_b, shape_b))
dims_c = dict(zip(sym_c, shape_c))
for axes in [axes_b, axes_k]:
for d in axes:
dim_a = dims_a[d] if d in sym_a else None
dim_b = dims_b[d] if d in sym_b else None
if dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible dimension {d}'
f' (a: {dim_a}; b: {dim_b})')
dims = dict()
dims.update(dims_a)
dims.update(dims_b)
dims.update(dims_c)
# look-up tables
TK = 16 if dtype == triton.fw.torch.float16 else 8
arrays = [(x, arrays[x]) for x in subscripted]
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
# hash for recompilation
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
# recompile if necessary
cache = _einsum.instance.kernel_cache
if name not in cache:
cachesize = len(cache)
cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
sym_a, sym_b, sym_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted)
self.kernel = cache[name]
# Initialize locks
if _einsum.instance.locks is None:
_einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
# Kernel arguments
dim_m = [dims[d] for d in axes_m]
dim_n = [dims[d] for d in axes_n]
dim_k = [dims[d] for d in axes_k]
dim_b = [dims[d] for d in axes_b]
M = reduce(mul, dim_m, 1)
N = reduce(mul, dim_n, 1)
K = reduce(mul, dim_k, 1)
B = reduce(mul, dim_b, 1)
stride_a = list(stride_a[:-1])
stride_b = list(stride_b[:-1])
stride_c = list(stride_c[:-1])
arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
alpha = 1.
div_m = 1
self.args = [None, None, None,
_einsum.instance.locks,
alpha, M, N, K, div_m] +\
dim_m + dim_n + dim_k + dim_b +\
stride_a + stride_b + stride_c
if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
self.args += [delta_a]
if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
self.args += [delta_b]
self.args += arrays
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
self.pos_c = 2
# pre-processor macros
TM = [16] + [x for x in [32, 64, 128] if x <= M]
TN = [16] + [x for x in [32, 64, 128] if x <= N]
TB = [x for x in [1, 2, 4] if x <= B]
MAX_GZ = K // 2048
MIN_GM = M // max(TM)
MIN_GN = N // max(TN)
MIN_GB = B // max(TB)
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
# information on compute
self.dtype = dtype
self.flops = 2 * B * M * N * K
self.sym_a = sym_a
self.sym_b = sym_b
self.sym_c = sym_c
# save equivalent mat-mul dimensions
self.matmul_B = B
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
def run(self, a, b, c, bench):
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
############################
## Forward
############################
instance_cache = dict()
@staticmethod
def forward(ctx, einsum, a, b, shape_c, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
# allocate output
dtype = a.dtype
c = triton.empty(shape_c, dtype=dtype)
# compile einsum instance
cache = _einsum.instance_cache
key = (einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape)
if key not in cache:
cache[key] = _einsum.instance(einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays)
instance = cache[key]
instance.run(a, b, c, bench)
# save information in context
ctx.flops = instance.flops
ctx.sym_a = instance.sym_a
ctx.sym_b = instance.sym_b
ctx.sym_c = instance.sym_c
ctx.matmul_B = instance.matmul_B
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.bench = bench
ctx.save_for_backward(a, b)
return c
############################
## Backward
############################
@staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
for i, expr in enumerate(sym_x):
if expr.is_symbol:
continue
sc = [x for x in expr.free_symbols if x in sym_c][0]
sx = sp.symbols(f'{prefix}{i}')
renamed[expr] = sx
inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
@staticmethod
def sym_to_expr(sym):
res = [f'({x})' for x in sym]
res = ''.join(res)
return res
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
sym_a = ctx.sym_a
sym_b = ctx.sym_b
sym_c = ctx.sym_c
inverse = dict()
renamed = dict()
_einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
_einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
sym_a = [renamed[x] if x in renamed else x for x in sym_a]
sym_b = [renamed[x] if x in renamed else x for x in sym_b]
sym_c = [inverse[x] if x in inverse else x for x in sym_c]
expr_a = _einsum.sym_to_expr(sym_a)
expr_b = _einsum.sym_to_expr(sym_b)
expr_c = _einsum.sym_to_expr(sym_c)
expr = f'{expr_c},{expr_b}->{expr_a}'
# gradient w.r.t. a is itself an einsum; one gradient per input (einsum, a, b, shape_c)
da = einsum(expr, dy, b, a.shape)
return None, da, None, None
einsum = _einsum.apply
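# A hedged usage sketch: unlike torch.einsum, the output shape is passed
# explicitly since it cannot be inferred for arbitrary index expressions:
#
#   import torch, triton
#   a = torch.randn(8, 64, 32, device='cuda').half()
#   b = torch.randn(8, 32, 16, device='cuda').half()
#   c = triton.ops.einsum('bmk,bkn->bmn', a, b, [8, 64, 16])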

75
python/triton/utils.py Normal file
View File

@@ -0,0 +1,75 @@
import triton.frameworks as fw
import triton._C.libtriton as libtriton
import numpy as np
import weakref
def cdiv(a, b):
return -(-a // b)
class tf_empty_proxy:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
self.tensor = None
def to_tensor(self):
assert self.tensor is not None
return self.tensor
def empty(shape, dtype):
if fw.has_tensorflow():
shape = [fw.tensorflow.constant(x) for x in shape]
shape = fw.tensorflow.stack(shape)
return tf_empty_proxy(shape, dtype)
#return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
return fw.torch.empty(shape, dtype=dtype, device='cuda:0')
def shape(A):
if fw.has_tensorflow():
return A.shape.as_list()
elif fw.has_torch():
return A.shape
else:
assert False
class id_dict:
# Lazy entry for e.g., tensorflow, when value of benchmark is
# not known at graph compile time
class lazy_entry:
def __init__(self, id):
self.id = id
def get(self):
return libtriton.retrieve_scalar(self.id)
def __init__(self):
self.data = dict()
def __delitem__(self, key):
del self.data[key]
@staticmethod
def _get_key(key):
if fw.has_tensorflow():
if isinstance(key, fw.tensorflow.Tensor):
key = id(key.op)
if fw.has_torch():
if isinstance(key, fw.torch.Tensor):
key = id(key)
return key
def __getitem__(self, key):
ret = self.data[id_dict._get_key(key)]
if isinstance(ret, id_dict.lazy_entry):
return ret.get()
return ret
def __len__(self):
return len(self.data)
def __setitem__(self, key, value):
self.data[id_dict._get_key(key)] = value